Coverage for python/lsst/source/injection/inject_engine.py: 8%
223 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-27 04:41 -0700
1# This file is part of source_injection.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["generate_galsim_objects", "inject_galsim_objects_into_exposure"]
26import os
27from collections import Counter
28from collections.abc import Generator
29from typing import Any
31import galsim
32import numpy as np
33import numpy.ma as ma
34from astropy.io import fits
35from astropy.table import Table
36from galsim import GalSimFFTSizeError
37from lsst.afw.geom import SkyWcs
38from lsst.afw.image import ExposureF, PhotoCalib
39from lsst.geom import Box2I, Point2D, Point2I, SpherePoint, arcseconds, degrees
40from lsst.pex.exceptions import InvalidParameterError, LogicError
def get_object_data(source_data: dict[str, Any], object_class: galsim.GSObject) -> dict[str, Any]:
    """Collect the constructor keyword arguments for a GalSim object class.

    The accepted arguments are read from the ``_req_params``,
    ``_opt_params`` and ``_single_params`` class attributes of
    ``object_class``; any of these that is absent is treated as empty.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    object_class : `galsim.gsobject.GSObject`
        Class of GSObject to match against.

    Returns
    -------
    object_data : `dict` [`str`, `Any`]
        Keyword arguments, taken from ``source_data``, to pass to the
        ``object_class`` constructor. Keys the class does not list are
        left out.

    Raises
    ------
    ValueError
        Raised if a required parameter is missing, or if a group of
        mutually exclusive parameters has none or more than one of its
        members present in ``source_data``.
    """
    required = getattr(object_class, "_req_params", {})
    optional = getattr(object_class, "_opt_params", {})
    exclusive_groups = getattr(object_class, "_single_params", {})

    object_data: dict[str, Any] = {}

    # Every required parameter must be supplied by the catalog row.
    for name in required:
        if name not in source_data:
            raise ValueError(f"Required parameter {name} not found in input catalog.")
        object_data[name] = source_data[name]

    # Optional parameters are copied only when present.
    object_data.update({name: source_data[name] for name in optional if name in source_data})

    # Each group is a set of mutually exclusive arguments: exactly one of
    # its members has to be present.
    for group in exclusive_groups:
        present = [name for name in group if name in source_data]
        if len(present) > 1:
            raise ValueError(f"Only one of {group.keys()} allowed for type {object_class}.")
        if not present:
            raise ValueError(f"One of the args {group.keys()} is required for type {object_class}.")
        chosen = present[0]
        object_data[chosen] = source_data[chosen]

    return object_data
def get_shear_data(
    source_data: dict[str, Any],
    shear_attributes: list[str] | tuple[str, ...] = (
        "g1",
        "g2",
        "g",
        "e1",
        "e2",
        "e",
        "eta1",
        "eta2",
        "eta",
        "q",
        "beta",
        "shear",
    ),
) -> dict[str, Any]:
    """Assemble a dictionary of allowed keyword arguments and their
    corresponding values to use when constructing a Shear.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    shear_attributes : `list` [`str`] or `tuple` [`str`, ...], optional
        Allowed shear attribute names. The default is an immutable tuple so
        that it cannot be shared and mutated across calls.

    Returns
    -------
    shear_data : `dict` [`str`, `Any`]
        Dictionary of source data to pass to the Shear constructor. Keys
        appear in the order of ``shear_attributes`` (duplicates ignored), so
        the result is deterministic across interpreter runs.
    """
    shear_data: dict[str, Any] = {}
    # dict.fromkeys removes duplicates while keeping the caller's ordering,
    # unlike a set intersection, whose order depends on string hashing.
    for shear_param in dict.fromkeys(shear_attributes):
        if shear_param not in source_data:
            continue
        value = source_data[shear_param]
        # The position angle is supplied as a number in degrees; GalSim
        # expects an Angle, so attach the unit here.
        shear_data[shear_param] = value * galsim.degrees if shear_param == "beta" else value
    return shear_data
133def generate_galsim_objects(
134 injection_catalog: Table,
135 photo_calib: PhotoCalib,
136 wcs: SkyWcs,
137 fits_alignment: str,
138 stamp_prefix: str = "",
139 logger: Any | None = None,
140) -> Generator[tuple[SpherePoint, Point2D, int, galsim.gsobject.GSObject], None, None]:
141 """Generate GalSim objects from an injection catalog.
143 Parameters
144 ----------
145 injection_catalog : `astropy.table.Table`
146 Table of sources to be injected.
147 photo_calib : `lsst.afw.image.PhotoCalib`
148 Photometric calibration used to calibrate injected sources.
149 wcs : `lsst.afw.geom.SkyWcs`
150 WCS used to calibrate injected sources.
151 fits_alignment : `str`
152 Alignment of the FITS image to the WCS. Allowed values: "wcs", "pixel".
153 stamp_prefix : `str`
154 Prefix to add to the stamp name.
155 logger : `lsst.utils.logging.LsstLogAdapter`, optional
156 Logger to use for logging messages.
158 Yields
159 ------
160 sky_coords : `lsst.geom.SpherePoint`
161 RA/Dec coordinates of the source.
162 pixel_coords : `lsst.geom.Point2D`
163 Pixel coordinates of the source.
164 draw_size : `int`
165 Size of the stamp to draw.
166 object : `galsim.gsobject.GSObject`
167 A fully specified and transformed GalSim object.
168 """
169 if logger:
170 source_types = Counter(injection_catalog["source_type"]) # type: ignore
171 grammar0 = "source" if len(injection_catalog) == 1 else "sources"
172 grammar1 = "type" if len(source_types) == 1 else "types"
173 logger.info(
174 "Generating %d injection %s consisting of %d unique %s: %s.",
175 len(injection_catalog),
176 grammar0,
177 len(source_types),
178 grammar1,
179 ", ".join(f"{k}({v})" for k, v in source_types.items()),
180 )
181 for source_data_full in injection_catalog:
182 items = dict(source_data_full).items() # type: ignore
183 source_data = {k: v for (k, v) in items if v is not ma.masked}
184 try:
185 sky_coords = SpherePoint(float(source_data["ra"]), float(source_data["dec"]), degrees)
186 except KeyError:
187 sky_coords = wcs.pixelToSky(float(source_data["x"]), float(source_data["y"]))
188 try:
189 pixel_coords = Point2D(source_data["x"], source_data["y"])
190 except KeyError:
191 pixel_coords = wcs.skyToPixel(sky_coords)
192 try:
193 inst_flux = photo_calib.magnitudeToInstFlux(source_data["mag"], pixel_coords)
194 except LogicError:
195 continue
196 try:
197 draw_size = int(source_data["draw_size"])
198 except KeyError:
199 draw_size = 0
201 if source_data["source_type"] == "Stamp":
202 stamp_file = stamp_prefix + source_data["stamp"]
203 object = make_galsim_stamp(stamp_file, fits_alignment, wcs, sky_coords, inst_flux)
204 elif source_data["source_type"] == "Trail":
205 object = make_galsim_trail(source_data, wcs, sky_coords, inst_flux)
206 else:
207 object = make_galsim_object(source_data, source_data["source_type"], inst_flux)
209 yield sky_coords, pixel_coords, draw_size, object
def make_galsim_object(
    source_data: dict[str, Any],
    source_type: str,
    inst_flux: float,
) -> galsim.gsobject.GSObject:
    """Build a GalSim profile of the named class from catalog data.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    source_type : `str`
        Name of a class in the ``galsim`` namespace (e.g. the profile type).
    inst_flux : `float`
        Instrumental flux of the source.

    Returns
    -------
    object : `galsim.gsobject.GSObject`
        The constructed profile, sheared when shear parameters are given,
        and scaled to ``inst_flux``.

    Notes
    -----
    A `TypeError` raised while applying the shear is swallowed, leaving the
    profile unsheared. NOTE(review): GalSim presumably reports incompatible
    shear keyword combinations this way, so such rows lose their shear
    silently — confirm this is intended.
    """
    galsim_class = getattr(galsim, source_type)
    # Constructor arguments (everything except shear and flux).
    profile = galsim_class(**get_object_data(source_data, galsim_class))

    # Gather shear keywords outside the try so that errors building them
    # are not mistaken for a rejected shear.
    shear_kwargs = get_shear_data(source_data)
    try:
        profile = profile.shear(**shear_kwargs)
    except TypeError:
        pass

    return profile.withFlux(inst_flux)
def make_galsim_trail(
    source_data: dict[str, Any],
    wcs: SkyWcs,
    sky_coords: SpherePoint,
    inst_flux: float,
    trail_thickness: float = 1e-6,
) -> galsim.gsobject.GSObject:
    """Build a linear trail profile with GalSim from catalog data.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data; must contain ``trail_length`` and may
        contain ``beta`` (rotation angle in degrees).
    wcs : `lsst.afw.geom.SkyWcs`
        World coordinate system.
    sky_coords : `lsst.geom.SpherePoint`
        Sky coordinates of the source; the WCS is linearized here.
    inst_flux : `float`
        Instrumental flux per unit trail length; the total flux is
        ``inst_flux * trail_length``.
    trail_thickness : `float`, optional
        Thickness of the trail in pixels.

    Returns
    -------
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.
    """
    trail_length = source_data["trail_length"]

    # A box much thinner than a pixel stands in for a 1-D line profile.
    trail = galsim.Box(trail_length, trail_thickness)
    if "beta" in source_data:
        trail = trail.rotate(source_data["beta"] * galsim.degrees)

    trail = trail.withFlux(inst_flux * trail_length)  # type: ignore

    # GalSim treats profiles as living in sky coordinates. The trail is
    # specified in image coordinates, so map it through the local Jacobian
    # of the pixel->sky transform (in arcseconds).
    jacobian = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
    return trail.transform(jacobian[0, 0], jacobian[0, 1], jacobian[1, 0], jacobian[1, 1])  # type: ignore
291def make_galsim_stamp(
292 stamp_file: str,
293 fits_alignment: str,
294 wcs: SkyWcs,
295 sky_coords: SpherePoint,
296 inst_flux: float,
297) -> galsim.gsobject.GSObject:
298 """Make a postage stamp with GalSim from a FITS file.
300 Parameters
301 ----------
302 stamp_file : `str`
303 Path to the FITS file containing the postage stamp.
304 fits_alignment : `str`
305 Alignment of the FITS image to the WCS. Allowed values: "wcs", "pixel".
306 wcs : `lsst.afw.geom.SkyWcs`
307 World coordinate system.
308 sky_coords : `lsst.geom.SpherePoint`
309 Sky coordinates of the source.
310 inst_flux : `float`
311 Instrumental flux of the source.
313 Returns
314 -------
315 object: `galsim.gsobject.GSObject`
316 A fully specified and transformed GalSim object.
317 """
318 stamp_file = stamp_file.strip()
319 if os.path.exists(stamp_file):
320 with fits.open(stamp_file) as hdul:
321 hdu_images = [hdu.is_image and hdu.size > 0 for hdu in hdul] # type: ignore
322 if any(hdu_images):
323 stamp_data = galsim.fits.read(stamp_file, read_header=True, hdu=np.where(hdu_images)[0][0])
324 else:
325 raise RuntimeError(f"Cannot find image in input FITS file {stamp_file}.")
326 match fits_alignment:
327 case "wcs":
328 # galsim.fits.read will always attach a WCS to its output.
329 # If it can't find a WCS in the FITS header, then it
330 # defaults to scale = 1.0 arcsec / pix. If that's the scale
331 # then we need to check if it was explicitly set or if it's
332 # just the default. If it's just the default then we should
333 # raise an exception.
334 if is_wcs_galsim_default(stamp_data.wcs, stamp_data.header): # type: ignore
335 raise RuntimeError(f"Cannot find WCS in input FITS file {stamp_file}")
336 case "pixel":
337 # We need to set stamp_data.wcs to the local target WCS.
338 linear_wcs = wcs.linearizePixelToSky(sky_coords, arcseconds)
339 mat = linear_wcs.getMatrix()
340 stamp_data.wcs = galsim.JacobianWCS( # type: ignore
341 mat[0, 0], mat[0, 1], mat[1, 0], mat[1, 1]
342 )
343 object = galsim.InterpolatedImage(stamp_data, calculate_stepk=False)
344 object = object.withFlux(inst_flux)
345 return object # type: ignore
346 else:
347 raise RuntimeError(f"Cannot locate input FITS postage stamp {stamp_file}.")
def is_wcs_galsim_default(
    wcs: galsim.fitswcs.GSFitsWCS,
    hdr: galsim.fits.FitsHeader,
) -> bool:
    """Tell whether a unit pixel-scale WCS is GalSim's fallback or explicit.

    Parameters
    ----------
    wcs : galsim.fitswcs.GSFitsWCS
        WCS attached to the image by GalSim; potentially the default.
    hdr : galsim.fits.FitsHeader
        Header as read in by GalSim.

    Returns
    -------
    is_default : bool
        True if ``wcs`` looks like the GalSim default
        (``galsim.PixelScale(1.0)`` with no supporting header keywords),
        False if it is explicitly described by ``hdr``.
    """
    # Anything other than a 1.0 pixel scale cannot be the fallback.
    if wcs != galsim.PixelScale(1.0):
        return False
    # A GS_WCS keyword (presumably written by GalSim itself) marks the WCS
    # as explicit.
    if hdr.get("GS_WCS") is not None:
        return False
    if hdr.get("CTYPE1", "LINEAR") != "LINEAR":
        # Non-linear CTYPE1: if any FITS WCS reader accepts the header,
        # treat the WCS as explicit. Readers that fail are ignored.
        for fits_wcs_type in galsim.fitswcs.fits_wcs_types:
            try:
                fits_wcs_type._readHeader(hdr)
            except Exception:
                continue
            return False
    # Otherwise the scale counts as explicit only if a scale keyword exists.
    return not any(keyword in hdr for keyword in ("CD1_1", "CDELT1"))
def _compute_psf_kernel_array(psf: Any, pixel_coords: Point2D, bbox: Box2I) -> np.ndarray | None:
    """Evaluate the PSF kernel image at ``pixel_coords``, with one fallback.

    If ``psf.computeKernelImage`` raises `InvalidParameterError` at
    ``pixel_coords`` (presumably because the point is off the region where
    the PSF is defined), it is retried once at the nearest point inside
    ``bbox``. No retry is made if that point is the same as the original.

    Returns
    -------
    psf_array : `numpy.ndarray` or `None`
        The kernel image array, or `None` if the PSF could not be evaluated.
    """
    try:
        return psf.computeKernelImage(pixel_coords).array
    except InvalidParameterError:
        pass
    contained_point = Point2D(
        np.clip(pixel_coords.x, bbox.minX, bbox.maxX), np.clip(pixel_coords.y, bbox.minY, bbox.maxY)
    )
    if pixel_coords == contained_point:  # no difference, so nothing to retry
        return None
    try:
        return psf.computeKernelImage(contained_point).array
    except InvalidParameterError:
        return None


def _bounds_to_box(bounds: galsim.BoundsI) -> Box2I:
    """Convert inclusive GalSim integer bounds into an LSST `Box2I`."""
    return Box2I(Point2I(bounds.xmin, bounds.ymin), Point2I(bounds.xmax, bounds.ymax))


def inject_galsim_objects_into_exposure(
    exposure: ExposureF,
    objects: Generator[tuple[SpherePoint, Point2D, int, galsim.gsobject.GSObject], None, None],
    mask_plane_name: str = "INJECTED",
    calib_flux_radius: float = 12.0,
    draw_size_max: int = 1000,
    logger: Any | None = None,
) -> tuple[list[int], list[galsim.BoundsI], list[bool], list[bool]]:
    """Inject sources into given exposure using GalSim.

    The exposure's image array is modified in place. Each drawn object's
    footprint is flagged in ``mask_plane_name``, and a small core around its
    position is flagged in ``mask_plane_name + "_CORE"``.

    Parameters
    ----------
    exposure : `lsst.afw.image.ExposureF`
        The exposure to inject synthetic sources into.
    objects : `Generator` [`tuple`, None, None]
        An iterator of tuples that contains (or generates) locations and object
        surface brightness profiles to inject. The tuples should contain the
        following elements: `lsst.geom.SpherePoint`, `lsst.geom.Point2D`,
        `int`, `galsim.gsobject.GSObject`.
    mask_plane_name : `str`
        Name of the mask plane to use for the injected sources.
    calib_flux_radius : `float`
        Radius in pixels to use for the flux calibration. This is used to
        produce the correct instrumental fluxes within the radius. The value
        should match that of the field defined in slot_CalibFlux_instFlux.
    draw_size_max : `int`
        Maximum allowed size of the drawn object. If the object is larger than
        this, the draw size will be clipped to this size. Values <= 0 disable
        clipping.
    logger : `lsst.utils.logging.LsstLogAdapter`, optional
        Logger to use for logging messages.

    Returns
    -------
    draw_sizes : `list` [`int`]
        Draw sizes of the injected sources (0 for skipped objects).
    common_bounds : `list` [`galsim.BoundsI`]
        Common bounds of the drawn objects.
    fft_size_errors : `list` [`bool`]
        Boolean flags indicating whether a GalSimFFTSizeError was raised.
    psf_compute_errors : `list` [`bool`]
        Boolean flags indicating whether a PSF computation error was raised.
    """
    exposure.mask.addMaskPlane(mask_plane_name)
    mask_plane_core_name = mask_plane_name + "_CORE"
    exposure.mask.addMaskPlane(mask_plane_core_name)
    if logger:
        logger.info(
            "Adding %s and %s mask planes to the exposure.",
            mask_plane_name,
            mask_plane_core_name,
        )
    psf = exposure.getPsf()
    wcs = exposure.getWcs()
    bbox = exposure.getBBox()
    full_bounds = galsim.BoundsI(bbox.minX, bbox.maxX, bbox.minY, bbox.maxY)
    # Wraps the exposure's pixel array, so drawing adds flux to the exposure.
    galsim_image = galsim.Image(exposure.image.array, bounds=full_bounds)
    pixel_scale = wcs.getPixelScale(bbox.getCenter()).asArcseconds()

    # Loop invariants: the mask planes were added above, and the aperture
    # correction is always evaluated at the PSF's average position. The
    # latter is computed lazily (on first use) so that nothing is evaluated
    # if no object reaches that point.
    bitvalue = exposure.mask.getPlaneBitMask(mask_plane_name)
    bitvalue_core = exposure.mask.getPlaneBitMask(mask_plane_core_name)
    aperture_correction: float | None = None

    draw_sizes: list[int] = []
    common_bounds: list[galsim.BoundsI] = []
    fft_size_errors: list[bool] = []
    psf_compute_errors: list[bool] = []
    for i, (sky_coords, pixel_coords, draw_size, gs_object) in enumerate(objects):
        # Instantiate default returns in case of early exit from this loop.
        draw_sizes.append(0)
        common_bounds.append(galsim.BoundsI())
        fft_size_errors.append(False)
        psf_compute_errors.append(False)

        # Get spatial coordinates and the local (linearized) WCS.
        posd = galsim.PositionD(pixel_coords.x, pixel_coords.y)
        # NOTE(review): the stamp centre pixel is the floor of the position;
        # the sub-pixel remainder is restored through ``offset`` when drawing.
        posi = galsim.PositionI(pixel_coords.x // 1, pixel_coords.y // 1)
        if logger:
            logger.debug("Injecting synthetic source at %s.", pixel_coords)
        mat = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
        galsim_wcs = galsim.JacobianWCS(mat[0, 0], mat[0, 1], mat[1, 0], mat[1, 1])

        # This check is here because sometimes the WCS is multivalued and
        # objects that should not be included were being included.
        galsim_pixel_scale = np.sqrt(galsim_wcs.pixelArea())
        if galsim_pixel_scale < pixel_scale / 2 or galsim_pixel_scale > pixel_scale * 2:
            continue

        # Compute the PSF at the object location (with a bbox fallback).
        psf_array = _compute_psf_kernel_array(psf, pixel_coords, bbox)
        if psf_array is None:
            psf_compute_errors[i] = True
            if logger:
                logger.debug("Cannot compute PSF for object at %s; flagging and skipping.", sky_coords)
            continue

        # Compute the aperture corrected PSF interpolated image.
        if aperture_correction is None:
            aperture_correction = psf.computeApertureFlux(calib_flux_radius, psf.getAveragePosition())
        psf_array /= aperture_correction
        galsim_psf = galsim.InterpolatedImage(galsim.Image(psf_array), wcs=galsim_wcs)

        # Convolve the object with the PSF and generate draw size.
        conv = galsim.Convolve(gs_object, galsim_psf)
        if draw_size == 0:
            draw_size = conv.getGoodImageSize(galsim_wcs.minLinearScale())  # type: ignore
        injection_draw_size = int(draw_size)
        injection_core_size = 3
        if draw_size_max > 0 and injection_draw_size > draw_size_max:
            if logger:
                logger.warning(
                    "Clipping draw size for object at %s from %d to %d pixels.",
                    sky_coords,
                    injection_draw_size,
                    draw_size_max,
                )
            injection_draw_size = draw_size_max
        draw_sizes[i] = injection_draw_size
        if injection_core_size > injection_draw_size:
            if logger:
                logger.debug(
                    "Clipping core size for object at %s from %d to %d pixels.",
                    sky_coords,
                    injection_core_size,
                    injection_draw_size,
                )
            injection_core_size = injection_draw_size
        sub_bounds = galsim.BoundsI(posi).withBorder(injection_draw_size // 2)
        object_common_bounds = full_bounds & sub_bounds
        common_bounds[i] = object_common_bounds  # type: ignore

        # Only inject the source if its stamp overlaps the exposure.
        if object_common_bounds.area() <= 0:
            if logger:
                logger.debug("No area overlap for object at %s; flagging and skipping.", sky_coords)
            continue

        common_image = galsim_image[object_common_bounds]
        offset = posd - object_common_bounds.true_center
        # Note, for calexp injection, pixel is already part of the PSF and
        # for coadd injection, it's incorrect to include the output pixel.
        # So for both cases, we draw using method='no_pixel'.
        try:
            conv.drawImage(common_image, add_to_image=True, offset=offset, wcs=galsim_wcs, method="no_pixel")
        except GalSimFFTSizeError as err:
            fft_size_errors[i] = True
            if logger:
                logger.debug(
                    "GalSimFFTSizeError raised for object at %s; flagging and skipping.\n%s",
                    sky_coords,
                    err,
                )
            continue
        exposure[_bounds_to_box(object_common_bounds)].mask.array |= bitvalue

        # Add a small (nominally 3 x 3) pixel mask centered on the object. The
        # mask must be large enough to always identify the core/peak of the
        # injected source, but small enough that it rarely overlaps real
        # sources.
        sub_bounds_core = galsim.BoundsI(posi).withBorder(injection_core_size // 2)
        object_common_bounds_core = full_bounds & sub_bounds_core
        if object_common_bounds_core.area() > 0:
            exposure[_bounds_to_box(object_common_bounds_core)].mask.array |= bitvalue_core

    return draw_sizes, common_bounds, fft_size_errors, psf_compute_errors