Coverage for python/lsst/source/injection/inject_engine.py: 7%

226 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 04:26 -0700

1# This file is part of source_injection. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["generate_galsim_objects", "inject_galsim_objects_into_exposure"] 

25 

26import os 

27from collections import Counter 

28from collections.abc import Generator 

29from typing import Any 

30 

31import galsim 

32import numpy as np 

33import numpy.ma as ma 

34from astropy.io import fits 

35from astropy.table import Table 

36from galsim import GalSimFFTSizeError 

37from lsst.afw.geom import SkyWcs 

38from lsst.afw.image import ExposureF, PhotoCalib 

39from lsst.geom import Box2I, Point2D, Point2I, SpherePoint, arcseconds, degrees 

40from lsst.pex.exceptions import InvalidParameterError, LogicError 

41 

42 

def get_object_data(source_data: dict[str, Any], object_class: galsim.GSObject) -> dict[str, Any]:
    """Select the keyword arguments accepted by a GSObject constructor.

    The accepted keywords are read from the ``_req_params``, ``_opt_params``
    and ``_single_params`` attributes of ``object_class``. An attribute that
    is absent is treated as empty.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    object_class : `galsim.gsobject.GSObject`
        Class of GSObject to match against.

    Returns
    -------
    object_data : `dict` [`str`, `Any`]
        Dictionary of source data to pass to the GSObject constructor.

    Raises
    ------
    ValueError
        Raised if a required parameter is missing from ``source_data``, or if
        a group of mutually exclusive parameters has either none or more than
        one of its members present in ``source_data``.
    """
    required = getattr(object_class, "_req_params", {})
    optional = getattr(object_class, "_opt_params", {})
    exclusive_groups = getattr(object_class, "_single_params", {})

    object_data = {}

    # Required parameters: every one of them must be supplied.
    for name in required:
        if name not in source_data:
            raise ValueError(f"Required parameter {name} not found in input catalog.")
        object_data[name] = source_data[name]

    # Optional parameters: copied across only when supplied.
    object_data.update({name: source_data[name] for name in optional if name in source_data})

    # Mutually exclusive parameters: exactly one member of each group.
    for group in exclusive_groups:
        supplied = [name for name in group if name in source_data]
        if len(supplied) > 1:
            raise ValueError(f"Only one of {group.keys()} allowed for type {object_class}.")
        if not supplied:
            raise ValueError(f"One of the args {group.keys()} is required for type {object_class}.")
        (chosen,) = supplied
        object_data[chosen] = source_data[chosen]

    return object_data

89 

90 

def get_shear_data(
    source_data: dict[str, Any],
    shear_attributes: list[str] = [
        "g1",
        "g2",
        "g",
        "e1",
        "e2",
        "e",
        "eta1",
        "eta2",
        "eta",
        "q",
        "beta",
        "shear",
    ],
) -> dict[str, Any]:
    """Pick out the entries of the source data that describe a shear.

    Only keys that appear in both ``source_data`` and ``shear_attributes``
    are kept. The ``beta`` value is multiplied by ``galsim.degrees``; every
    other value is passed through unchanged.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    shear_attributes : `list` [`str`], optional
        List of allowed shear attributes. The default list is only read,
        never modified.

    Returns
    -------
    shear_data : `dict` [`str`, `Any`]
        Dictionary of source data to pass to the Shear constructor.
    """
    shear_data = {}
    for name in set(shear_attributes) & set(source_data.keys()):
        value = source_data[name]
        if name == "beta":
            # The position angle is given a GalSim angle unit.
            value = value * galsim.degrees
        shear_data[name] = value
    return shear_data

131 

132 

def generate_galsim_objects(
    injection_catalog: Table,
    photo_calib: PhotoCalib,
    wcs: SkyWcs,
    fits_alignment: str,
    stamp_prefix: str = "",
    logger: Any | None = None,
) -> Generator[tuple[SpherePoint, Point2D, int, galsim.gsobject.GSObject], None, None]:
    """Lazily turn the rows of an injection catalog into GalSim objects.

    Masked catalog entries are dropped from a row before it is interpreted.
    Sky coordinates come from ``ra``/``dec`` when both are present and are
    otherwise derived from ``x``/``y`` through the WCS; pixel coordinates
    come from ``x``/``y`` when both are present and are otherwise derived
    from the sky coordinates. A row for which the magnitude to instrumental
    flux conversion raises `~lsst.pex.exceptions.LogicError` is skipped
    without yielding anything.

    Parameters
    ----------
    injection_catalog : `astropy.table.Table`
        Table of sources to be injected.
    photo_calib : `lsst.afw.image.PhotoCalib`
        Photometric calibration used to calibrate injected sources.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS used to calibrate injected sources.
    fits_alignment : `str`
        Alignment of the FITS image to the WCS. Allowed values: "wcs", "pixel".
    stamp_prefix : `str`
        Prefix to add to the stamp name.
    logger : `lsst.utils.logging.LsstLogAdapter`, optional
        Logger to use for logging messages.

    Yields
    ------
    sky_coords : `lsst.geom.SpherePoint`
        RA/Dec coordinates of the source.
    pixel_coords : `lsst.geom.Point2D`
        Pixel coordinates of the source.
    draw_size : `int`
        Size of the stamp to draw; 0 when the row has no ``draw_size``.
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.
    """
    if logger:
        type_counts = Counter(injection_catalog["source_type"])  # type: ignore
        n_sources = len(injection_catalog)
        n_types = len(type_counts)
        logger.info(
            "Generating %d injection %s consisting of %d unique %s: %s.",
            n_sources,
            "source" if n_sources == 1 else "sources",
            n_types,
            "type" if n_types == 1 else "types",
            ", ".join(f"{k}({v})" for k, v in type_counts.items()),
        )

    for row in injection_catalog:
        # Masked entries are treated as if the column were absent.
        source_data = {
            column: value for column, value in dict(row).items() if value is not ma.masked  # type: ignore
        }

        # Sky position: prefer RA/Dec, fall back to the pixel position.
        try:
            sky_coords = SpherePoint(float(source_data["ra"]), float(source_data["dec"]), degrees)
        except KeyError:
            sky_coords = wcs.pixelToSky(float(source_data["x"]), float(source_data["y"]))

        # Pixel position: prefer x/y, fall back to the sky position.
        try:
            pixel_coords = Point2D(source_data["x"], source_data["y"])
        except KeyError:
            pixel_coords = wcs.skyToPixel(sky_coords)

        try:
            inst_flux = photo_calib.magnitudeToInstFlux(source_data["mag"], pixel_coords)
        except LogicError:
            continue

        draw_size = int(source_data.get("draw_size", 0))

        # Dispatch on the source type; anything other than "Stamp" or "Trail"
        # is handed to make_galsim_object as a GalSim class name.
        source_type = source_data["source_type"]
        if source_type == "Stamp":
            stamp_file = stamp_prefix + source_data["stamp"]
            profile = make_galsim_stamp(stamp_file, fits_alignment, wcs, sky_coords, inst_flux)
        elif source_type == "Trail":
            profile = make_galsim_trail(source_data, wcs, sky_coords, inst_flux)
        else:
            profile = make_galsim_object(source_data, source_type, inst_flux, logger)

        yield sky_coords, pixel_coords, draw_size, profile

210 

211 

def make_galsim_object(
    source_data: dict[str, Any],
    source_type: str,
    inst_flux: float,
    logger: Any | None = None,
) -> galsim.gsobject.GSObject:
    """Build a GalSim object of the named type from a collection of source data.

    The class is looked up by name on the ``galsim`` module, constructed from
    the matching entries of ``source_data``, optionally sheared, and finally
    given the requested flux.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    source_type : `str`
        Type of the source, corresponding to a GalSim class.
    inst_flux : `float`
        Instrumental flux of the source.
    logger : `~lsst.utils.logging.LsstLogAdapter`, optional
        Logger used to warn when the shear cannot be applied.

    Returns
    -------
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.
    """
    # Construct the profile from its non-shear, non-flux parameters.
    galsim_class = getattr(galsim, source_type)
    profile = galsim_class(**get_object_data(source_data, galsim_class))

    # Apply an area-preserving shear when any shear parameter is present.
    shear_kwargs = get_shear_data(source_data)
    if shear_kwargs:
        try:
            profile = profile.shear(**shear_kwargs)
        except TypeError as err:
            # A failed shear is not fatal: the unsheared profile is used.
            if logger:
                logger.warning("Cannot apply shear to GalSim object: %s", err)

    # Apply the instrumental flux and return.
    return profile.withFlux(inst_flux)

251 

252 

def make_galsim_trail(
    source_data: dict[str, Any],
    wcs: SkyWcs,
    sky_coords: SpherePoint,
    inst_flux: float,
    trail_thickness: float = 1e-6,
) -> galsim.gsobject.GSObject:
    """Build a GalSim trail (a very thin box) from a collection of source data.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data. ``trail_length`` is required; ``beta``,
        when present, rotates the trail by that many degrees.
    wcs : `lsst.afw.geom.SkyWcs`
        World coordinate system.
    sky_coords : `lsst.geom.SpherePoint`
        Sky coordinates of the source.
    inst_flux : `float`
        Instrumental flux of the source. The flux given to the trail is this
        value multiplied by ``trail_length``.
    trail_thickness : `float`
        Thickness of the trail in pixels.

    Returns
    -------
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.
    """
    trail_length = source_data["trail_length"]

    # A box much thinner than a pixel (default thickness 1e-6) mimics a line
    # surface brightness profile.
    trail = galsim.Box(trail_length, trail_thickness)
    try:
        trail = trail.rotate(source_data["beta"] * galsim.degrees)
    except KeyError:
        # No position angle supplied: leave the trail unrotated.
        pass
    trail = trail.withFlux(inst_flux * trail_length)  # type: ignore

    # GalSim objects are assumed to be in sky-coords. The trail is defined in
    # image-coords, so it is mapped through the local linear WCS here.
    jacobian = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
    trail = trail.transform(jacobian[0, 0], jacobian[0, 1], jacobian[1, 0], jacobian[1, 1])
    return trail  # type: ignore

294 

295 

def make_galsim_stamp(
    stamp_file: str,
    fits_alignment: str,
    wcs: SkyWcs,
    sky_coords: SpherePoint,
    inst_flux: float,
) -> galsim.gsobject.GSObject:
    """Build a GalSim postage stamp object from an image stored in a FITS file.

    The first HDU that is an image of non-zero size is used.

    Parameters
    ----------
    stamp_file : `str`
        Path to the FITS file containing the postage stamp. Leading and
        trailing whitespace is stripped before use.
    fits_alignment : `str`
        Alignment of the FITS image to the WCS. Allowed values: "wcs", "pixel".
        Any other value leaves the WCS read by GalSim untouched.
    wcs : `lsst.afw.geom.SkyWcs`
        World coordinate system.
    sky_coords : `lsst.geom.SpherePoint`
        Sky coordinates of the source.
    inst_flux : `float`
        Instrumental flux of the source.

    Returns
    -------
    object: `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.

    Raises
    ------
    RuntimeError
        Raised if the file does not exist, if it holds no usable image HDU,
        or if ``fits_alignment`` is "wcs" and the file carries no explicit WCS.
    """
    stamp_file = stamp_file.strip()
    if not os.path.exists(stamp_file):
        raise RuntimeError(f"Cannot locate input FITS postage stamp {stamp_file}.")

    # Flag every HDU that holds a non-empty image.
    with fits.open(stamp_file) as hdul:
        hdu_images = [hdu.is_image and hdu.size > 0 for hdu in hdul]  # type: ignore
    if not any(hdu_images):
        raise RuntimeError(f"Cannot find image in input FITS file {stamp_file}.")
    stamp_data = galsim.fits.read(stamp_file, read_header=True, hdu=np.where(hdu_images)[0][0])

    if fits_alignment == "wcs":
        # galsim.fits.read always attaches a WCS to its output, defaulting to
        # scale = 1.0 arcsec / pix when the header has none. A WCS that is
        # merely that default means the file has no usable WCS.
        if is_wcs_galsim_default(stamp_data.wcs, stamp_data.header):  # type: ignore
            raise RuntimeError(f"Cannot find WCS in input FITS file {stamp_file}")
    elif fits_alignment == "pixel":
        # Replace the stamp WCS with the local linearized target WCS.
        jacobian = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
        stamp_data.wcs = galsim.JacobianWCS(  # type: ignore
            jacobian[0, 0], jacobian[0, 1], jacobian[1, 0], jacobian[1, 1]
        )

    stamp = galsim.InterpolatedImage(stamp_data, calculate_stepk=False)
    stamp = stamp.withFlux(inst_flux)
    return stamp  # type: ignore

353 

354 

def is_wcs_galsim_default(
    wcs: galsim.fitswcs.GSFitsWCS,
    hdr: galsim.fits.FitsHeader,
) -> bool:
    """Tell an explicit ``galsim.PixelScale(1.0)`` WCS apart from the value
    GalSim falls back to when a header has no WCS.

    Parameters
    ----------
    wcs : galsim.fitswcs.GSFitsWCS
        Potentially default WCS.
    hdr : galsim.fits.FitsHeader
        Header as read in by GalSim.

    Returns
    -------
    is_default : bool
        True if default, False if explicitly set in header.
    """
    # Anything other than a unit pixel scale cannot be the default.
    if wcs != galsim.PixelScale(1.0):
        return False
    # A GS_WCS entry means the WCS was written out explicitly.
    if hdr.get("GS_WCS") is not None:
        return False

    if hdr.get("CTYPE1", "LINEAR") != "LINEAR":
        # Non-linear CTYPE1: if any of GalSim's FITS WCS readers accepts the
        # header, assume the WCS is explicit. Reader failures are ignored on
        # purpose; they only mean that this reader does not apply.
        # NOTE: _readHeader is a private GalSim method.
        for wcs_type in galsim.fitswcs.fits_wcs_types:
            try:
                wcs_type._readHeader(hdr)
            except Exception:
                continue
            return False

    # The WCS is the default unless the header carries a scale keyword.
    return not ("CD1_1" in hdr or "CDELT1" in hdr)

389 

390 

def inject_galsim_objects_into_exposure(
    exposure: ExposureF,
    objects: Generator[tuple[SpherePoint, Point2D, int, galsim.gsobject.GSObject], None, None],
    mask_plane_name: str = "INJECTED",
    calib_flux_radius: float = 12.0,
    draw_size_max: int = 1000,
    logger: Any | None = None,
) -> tuple[list[int], list[galsim.BoundsI], list[bool], list[bool]]:
    """Draw GalSim objects onto an exposure and flag the affected pixels.

    The exposure image is modified in place. Two mask planes are added:
    ``mask_plane_name``, set over the full drawn footprint of each source,
    and ``mask_plane_name + "_CORE"``, set over a small box centered on each
    source. Every returned list has one entry per input object; entries keep
    their defaults (0, empty bounds, False, False) for objects that are
    skipped before the corresponding value is computed.

    Parameters
    ----------
    exposure : `lsst.afw.image.ExposureF`
        The exposure to inject synthetic sources into.
    objects : `Generator` [`tuple`, None, None]
        An iterator of tuples that contains (or generates) locations and object
        surface brightness profiles to inject. The tuples should contain the
        following elements: `lsst.geom.SpherePoint`, `lsst.geom.Point2D`,
        `int`, `galsim.gsobject.GSObject`.
    mask_plane_name : `str`
        Name of the mask plane to use for the injected sources.
    calib_flux_radius : `float`
        Radius in pixels to use for the flux calibration. This is used to
        produce the correct instrumental fluxes within the radius. The value
        should match that of the field defined in slot_CalibFlux_instFlux.
    draw_size_max : `int`
        Maximum allowed size of the drawn object. If the object is larger than
        this, the draw size will be clipped to this size. A value that is not
        positive disables the clipping.
    logger : `lsst.utils.logging.LsstLogAdapter`, optional
        Logger to use for logging messages.

    Returns
    -------
    draw_sizes : `list` [`int`]
        Draw sizes of the injected sources.
    common_bounds : `list` [`galsim.BoundsI`]
        Common bounds of the drawn objects.
    fft_size_errors : `list` [`bool`]
        Boolean flags indicating whether a GalSimFFTSizeError was raised.
    psf_compute_errors : `list` [`bool`]
        Boolean flags indicating whether a PSF computation error was raised.
    """

    def bounds_to_box(bounds: galsim.BoundsI) -> Box2I:
        # Convert inclusive GalSim integer bounds to an LSST integer box.
        return Box2I(Point2I(bounds.xmin, bounds.ymin), Point2I(bounds.xmax, bounds.ymax))

    core_plane_name = mask_plane_name + "_CORE"
    exposure.mask.addMaskPlane(mask_plane_name)
    exposure.mask.addMaskPlane(core_plane_name)
    if logger:
        logger.info(
            "Adding %s and %s mask planes to the exposure.",
            mask_plane_name,
            core_plane_name,
        )

    psf = exposure.getPsf()
    wcs = exposure.getWcs()
    bbox = exposure.getBBox()
    full_bounds = galsim.BoundsI(bbox.minX, bbox.maxX, bbox.minY, bbox.maxY)
    # GalSim image built directly from the exposure's pixel array, so that
    # drawing into it updates the exposure.
    galsim_image = galsim.Image(exposure.image.array, bounds=full_bounds)
    pixel_scale = wcs.getPixelScale(bbox.getCenter()).asArcseconds()

    draw_sizes: list[int] = []
    common_bounds: list[galsim.BoundsI] = []
    fft_size_errors: list[bool] = []
    psf_compute_errors: list[bool] = []

    for index, (sky_coords, pixel_coords, draw_size, profile) in enumerate(objects):
        # Record default results first so that every object has an entry even
        # when this iteration exits early.
        draw_sizes.append(0)
        common_bounds.append(galsim.BoundsI())
        fft_size_errors.append(False)
        psf_compute_errors.append(False)

        # Exact and integer (floored) pixel positions, plus the local WCS.
        posd = galsim.PositionD(pixel_coords.x, pixel_coords.y)
        posi = galsim.PositionI(pixel_coords.x // 1, pixel_coords.y // 1)
        if logger:
            logger.debug(f"Injecting synthetic source at {pixel_coords}.")
        jacobian = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
        galsim_wcs = galsim.JacobianWCS(jacobian[0, 0], jacobian[0, 1], jacobian[1, 0], jacobian[1, 1])

        # This check is here because sometimes the WCS is multivalued and
        # objects that should not be included were being included.
        local_pixel_scale = np.sqrt(galsim_wcs.pixelArea())
        if local_pixel_scale < pixel_scale / 2 or local_pixel_scale > pixel_scale * 2:
            continue

        # Compute the PSF at the object location; on failure retry once at
        # the nearest point inside the bounding box, if that point differs.
        try:
            psf_array = psf.computeKernelImage(pixel_coords).array
        except InvalidParameterError:
            contained_point = Point2D(
                np.clip(pixel_coords.x, bbox.minX, bbox.maxX), np.clip(pixel_coords.y, bbox.minY, bbox.maxY)
            )
            psf_array = None
            if not (pixel_coords == contained_point):
                try:
                    psf_array = psf.computeKernelImage(contained_point).array
                except InvalidParameterError:
                    pass
            if psf_array is None:
                psf_compute_errors[index] = True
                if logger:
                    logger.debug("Cannot compute PSF for object at %s; flagging and skipping.", sky_coords)
                continue

        # Compute the aperture corrected PSF interpolated image.
        aperture_correction = psf.computeApertureFlux(calib_flux_radius, psf.getAveragePosition())
        psf_array /= aperture_correction
        galsim_psf = galsim.InterpolatedImage(galsim.Image(psf_array), wcs=galsim_wcs)

        # Convolve the object with the PSF; a draw size of 0 means that GalSim
        # chooses the size.
        conv = galsim.Convolve(profile, galsim_psf)
        if draw_size == 0:
            draw_size = conv.getGoodImageSize(galsim_wcs.minLinearScale())  # type: ignore
        footprint_size = int(draw_size)
        core_size = 3
        if draw_size_max > 0 and footprint_size > draw_size_max:
            if logger:
                logger.warning(
                    "Clipping draw size for object at %s from %d to %d pixels.",
                    sky_coords,
                    footprint_size,
                    draw_size_max,
                )
            footprint_size = draw_size_max
        draw_sizes[index] = footprint_size
        if core_size > footprint_size:
            if logger:
                logger.debug(
                    "Clipping core size for object at %s from %d to %d pixels.",
                    sky_coords,
                    core_size,
                    footprint_size,
                )
            core_size = footprint_size

        # Part of the draw footprint that lies on the exposure.
        footprint_bounds = full_bounds & galsim.BoundsI(posi).withBorder(footprint_size // 2)
        common_bounds[index] = footprint_bounds  # type: ignore

        # Nothing to inject when the footprint misses the exposure entirely.
        if not (footprint_bounds.area() > 0):
            if logger:
                logger.debug("No area overlap for object at %s; flagging and skipping.", sky_coords)
            continue

        common_image = galsim_image[footprint_bounds]
        offset = posd - footprint_bounds.true_center
        # Note, for calexp injection, pixel is already part of the PSF and
        # for coadd injection, it's incorrect to include the output pixel.
        # So for both cases, we draw using method='no_pixel'.
        try:
            conv.drawImage(common_image, add_to_image=True, offset=offset, wcs=galsim_wcs, method="no_pixel")
        except GalSimFFTSizeError as err:
            fft_size_errors[index] = True
            if logger:
                logger.debug(
                    "GalSimFFTSizeError raised for object at %s; flagging and skipping.\n%s",
                    sky_coords,
                    err,
                )
            continue

        # Flag the whole drawn footprint.
        footprint_bit = exposure.mask.getPlaneBitMask(mask_plane_name)
        exposure[bounds_to_box(footprint_bounds)].mask.array |= footprint_bit

        # Add a 3 x 3 pixel mask centered on the object. The mask must be
        # large enough to always identify the core/peak of the injected
        # source, but small enough that it rarely overlaps real sources.
        core_bounds = full_bounds & galsim.BoundsI(posi).withBorder(core_size // 2)
        if core_bounds.area() > 0:
            core_bit = exposure.mask.getPlaneBitMask(core_plane_name)
            exposure[bounds_to_box(core_bounds)].mask.array |= core_bit

    return draw_sizes, common_bounds, fft_size_errors, psf_compute_errors