Coverage for python/lsst/source/injection/inject_engine.py: 8%

223 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-19 05:51 -0700

1# This file is part of source_injection. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["generate_galsim_objects", "inject_galsim_objects_into_exposure"] 

25 

26import os 

27from collections import Counter 

28from collections.abc import Generator 

29from typing import Any 

30 

31import galsim 

32import numpy as np 

33import numpy.ma as ma 

34from astropy.io import fits 

35from astropy.table import Table 

36from galsim import GalSimFFTSizeError 

37from lsst.afw.geom import SkyWcs 

38from lsst.afw.image import ExposureF, PhotoCalib 

39from lsst.geom import Box2I, Point2D, Point2I, SpherePoint, arcseconds, degrees 

40from lsst.pex.exceptions import InvalidParameterError, LogicError 

41 

42 

def get_object_data(source_data: dict[str, Any], object_class: galsim.GSObject) -> dict[str, Any]:
    """Pick out the constructor keyword arguments for a GSObject class.

    The parameter specification is read from the ``_req_params``,
    ``_opt_params`` and ``_single_params`` attributes of ``object_class``;
    any attribute that is absent is treated as empty.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    object_class : `galsim.gsobject.GSObject`
        Class of GSObject to match against.

    Returns
    -------
    object_data : `dict` [`str`, `Any`]
        Subset of ``source_data`` to pass to the GSObject constructor, in
        the order: required, optional, then mutually exclusive parameters.

    Raises
    ------
    ValueError
        Raised if a required parameter is missing, or if a group of
        mutually exclusive parameters has none or more than one of its
        members present in ``source_data``.
    """
    required = getattr(object_class, "_req_params", {})
    optional = getattr(object_class, "_opt_params", {})
    exclusive_groups = getattr(object_class, "_single_params", {})

    # Required parameters: every one of them must be supplied.
    for name in required:
        if name not in source_data:
            raise ValueError(f"Required parameter {name} not found in input catalog.")
    object_data = {name: source_data[name] for name in required}

    # Optional parameters: copied across only when supplied.
    object_data.update({name: source_data[name] for name in optional if name in source_data})

    # Mutually exclusive parameters: exactly one member of each group.
    for group in exclusive_groups:
        supplied = [name for name in group if name in source_data]
        if len(supplied) > 1:
            raise ValueError(f"Only one of {group.keys()} allowed for type {object_class}.")
        if not supplied:
            raise ValueError(f"One of the args {group.keys()} is required for type {object_class}.")
        object_data[supplied[0]] = source_data[supplied[0]]

    return object_data

89 

90 

91def get_shear_data( 

92 source_data: dict[str, Any], 

93 shear_attributes: list[str] = [ 

94 "g1", 

95 "g2", 

96 "g", 

97 "e1", 

98 "e2", 

99 "e", 

100 "eta1", 

101 "eta2", 

102 "eta", 

103 "q", 

104 "beta", 

105 "shear", 

106 ], 

107) -> dict[str, Any]: 

108 """Assemble a dictionary of allowed keyword arguments and their 

109 corresponding values to use when constructing a Shear. 

110 

111 Parameters 

112 ---------- 

113 source_data : `dict` [`str`, `Any`] 

114 Dictionary of source data. 

115 shear_attributes : `list` [`str`], optional 

116 List of allowed shear attributes. 

117 

118 Returns 

119 ------- 

120 shear_data : `dict` [`str`, `Any`] 

121 Dictionary of source data to pass to the Shear constructor. 

122 """ 

123 shear_params = set(shear_attributes) & set(source_data.keys()) 

124 shear_data = {} 

125 for shear_param in shear_params: 

126 if shear_param == "beta": 

127 shear_data.update({shear_param: source_data[shear_param] * galsim.degrees}) 

128 else: 

129 shear_data.update({shear_param: source_data[shear_param]}) 

130 return shear_data 

131 

132 

133def generate_galsim_objects( 

134 injection_catalog: Table, 

135 photo_calib: PhotoCalib, 

136 wcs: SkyWcs, 

137 fits_alignment: str, 

138 stamp_prefix: str = "", 

139 logger: Any | None = None, 

140) -> Generator[tuple[SpherePoint, Point2D, int, galsim.gsobject.GSObject], None, None]: 

141 """Generate GalSim objects from an injection catalog. 

142 

143 Parameters 

144 ---------- 

145 injection_catalog : `astropy.table.Table` 

146 Table of sources to be injected. 

147 photo_calib : `lsst.afw.image.PhotoCalib` 

148 Photometric calibration used to calibrate injected sources. 

149 wcs : `lsst.afw.geom.SkyWcs` 

150 WCS used to calibrate injected sources. 

151 fits_alignment : `str` 

152 Alignment of the FITS image to the WCS. Allowed values: "wcs", "pixel". 

153 stamp_prefix : `str` 

154 Prefix to add to the stamp name. 

155 logger : `lsst.utils.logging.LsstLogAdapter`, optional 

156 Logger to use for logging messages. 

157 

158 Yields 

159 ------ 

160 sky_coords : `lsst.geom.SpherePoint` 

161 RA/Dec coordinates of the source. 

162 pixel_coords : `lsst.geom.Point2D` 

163 Pixel coordinates of the source. 

164 draw_size : `int` 

165 Size of the stamp to draw. 

166 object : `galsim.gsobject.GSObject` 

167 A fully specified and transformed GalSim object. 

168 """ 

169 if logger: 

170 source_types = Counter(injection_catalog["source_type"]) # type: ignore 

171 grammar0 = "source" if len(injection_catalog) == 1 else "sources" 

172 grammar1 = "type" if len(source_types) == 1 else "types" 

173 logger.info( 

174 "Generating %d injection %s consisting of %d unique %s: %s.", 

175 len(injection_catalog), 

176 grammar0, 

177 len(source_types), 

178 grammar1, 

179 ", ".join(f"{k}({v})" for k, v in source_types.items()), 

180 ) 

181 for source_data_full in injection_catalog: 

182 items = dict(source_data_full).items() # type: ignore 

183 source_data = {k: v for (k, v) in items if v is not ma.masked} 

184 try: 

185 sky_coords = SpherePoint(float(source_data["ra"]), float(source_data["dec"]), degrees) 

186 except KeyError: 

187 sky_coords = wcs.pixelToSky(float(source_data["x"]), float(source_data["y"])) 

188 try: 

189 pixel_coords = Point2D(source_data["x"], source_data["y"]) 

190 except KeyError: 

191 pixel_coords = wcs.skyToPixel(sky_coords) 

192 try: 

193 inst_flux = photo_calib.magnitudeToInstFlux(source_data["mag"], pixel_coords) 

194 except LogicError: 

195 continue 

196 try: 

197 draw_size = int(source_data["draw_size"]) 

198 except KeyError: 

199 draw_size = 0 

200 

201 if source_data["source_type"] == "Stamp": 

202 stamp_file = stamp_prefix + source_data["stamp"] 

203 object = make_galsim_stamp(stamp_file, fits_alignment, wcs, sky_coords, inst_flux) 

204 elif source_data["source_type"] == "Trail": 

205 object = make_galsim_trail(source_data, wcs, sky_coords, inst_flux) 

206 else: 

207 object = make_galsim_object(source_data, source_data["source_type"], inst_flux) 

208 

209 yield sky_coords, pixel_coords, draw_size, object 

210 

211 

def make_galsim_object(
    source_data: dict[str, Any],
    source_type: str,
    inst_flux: float,
) -> galsim.gsobject.GSObject:
    """Build a GalSim profile of the named type from catalog source data.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    source_type : `str`
        Type of the source, corresponding to a GalSim class. The class is
        looked up by this name as an attribute of the ``galsim`` module.
    inst_flux : `float`
        Instrumental flux of the source.

    Returns
    -------
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.
    """
    # Construct the bare profile from its non-shear, non-flux parameters.
    profile_class = getattr(galsim, source_type)
    profile = profile_class(**get_object_data(source_data, profile_class))

    # Apply an area-preserving shear. A TypeError raised by the shear call
    # leaves the profile unsheared rather than propagating.
    shear_kwargs = get_shear_data(source_data)
    try:
        sheared = profile.shear(**shear_kwargs)
    except TypeError:
        sheared = profile

    # Scale to the requested instrumental flux.
    return sheared.withFlux(inst_flux)

246 

247 

def make_galsim_trail(
    source_data: dict[str, Any],
    wcs: SkyWcs,
    sky_coords: SpherePoint,
    inst_flux: float,
    trail_thickness: float = 1e-6,
) -> galsim.gsobject.GSObject:
    """Build a linear trail as a very thin GalSim box profile.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data. Must contain ``trail_length``; a
        ``beta`` entry, when present, is used as a rotation in degrees.
    wcs : `lsst.afw.geom.SkyWcs`
        World coordinate system.
    sky_coords : `lsst.geom.SpherePoint`
        Sky coordinates of the source.
    inst_flux : `float`
        Instrumental flux of the source. The flux assigned to the trail is
        this value multiplied by the trail length.
    trail_thickness : `float`
        Thickness of the trail in pixels.

    Returns
    -------
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.
    """
    trail_length = source_data["trail_length"]

    # A box far thinner than a pixel (default thickness 1e-6) stands in for
    # a line surface brightness profile.
    trail = galsim.Box(trail_length, trail_thickness)
    if "beta" in source_data:
        trail = trail.rotate(source_data["beta"] * galsim.degrees)
    trail = trail.withFlux(inst_flux * trail_length)  # type: ignore

    # GalSim treats profiles as living in sky coordinates. The trail above is
    # specified in image coordinates, so map it through the local linear WCS.
    jacobian = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
    dudx, dudy = jacobian[0, 0], jacobian[0, 1]
    dvdx, dvdy = jacobian[1, 0], jacobian[1, 1]
    return trail.transform(dudx, dudy, dvdx, dvdy)  # type: ignore

289 

290 

291def make_galsim_stamp( 

292 stamp_file: str, 

293 fits_alignment: str, 

294 wcs: SkyWcs, 

295 sky_coords: SpherePoint, 

296 inst_flux: float, 

297) -> galsim.gsobject.GSObject: 

298 """Make a postage stamp with GalSim from a FITS file. 

299 

300 Parameters 

301 ---------- 

302 stamp_file : `str` 

303 Path to the FITS file containing the postage stamp. 

304 fits_alignment : `str` 

305 Alignment of the FITS image to the WCS. Allowed values: "wcs", "pixel". 

306 wcs : `lsst.afw.geom.SkyWcs` 

307 World coordinate system. 

308 sky_coords : `lsst.geom.SpherePoint` 

309 Sky coordinates of the source. 

310 inst_flux : `float` 

311 Instrumental flux of the source. 

312 

313 Returns 

314 ------- 

315 object: `galsim.gsobject.GSObject` 

316 A fully specified and transformed GalSim object. 

317 """ 

318 stamp_file = stamp_file.strip() 

319 if os.path.exists(stamp_file): 

320 with fits.open(stamp_file) as hdul: 

321 hdu_images = [hdu.is_image and hdu.size > 0 for hdu in hdul] # type: ignore 

322 if any(hdu_images): 

323 stamp_data = galsim.fits.read(stamp_file, read_header=True, hdu=np.where(hdu_images)[0][0]) 

324 else: 

325 raise RuntimeError(f"Cannot find image in input FITS file {stamp_file}.") 

326 match fits_alignment: 

327 case "wcs": 

328 # galsim.fits.read will always attach a WCS to its output. 

329 # If it can't find a WCS in the FITS header, then it 

330 # defaults to scale = 1.0 arcsec / pix. If that's the scale 

331 # then we need to check if it was explicitly set or if it's 

332 # just the default. If it's just the default then we should 

333 # raise an exception. 

334 if is_wcs_galsim_default(stamp_data.wcs, stamp_data.header): # type: ignore 

335 raise RuntimeError(f"Cannot find WCS in input FITS file {stamp_file}") 

336 case "pixel": 

337 # We need to set stamp_data.wcs to the local target WCS. 

338 linear_wcs = wcs.linearizePixelToSky(sky_coords, arcseconds) 

339 mat = linear_wcs.getMatrix() 

340 stamp_data.wcs = galsim.JacobianWCS( # type: ignore 

341 mat[0, 0], mat[0, 1], mat[1, 0], mat[1, 1] 

342 ) 

343 object = galsim.InterpolatedImage(stamp_data, calculate_stepk=False) 

344 object = object.withFlux(inst_flux) 

345 return object # type: ignore 

346 else: 

347 raise RuntimeError(f"Cannot locate input FITS postage stamp {stamp_file}.") 

348 

349 

def is_wcs_galsim_default(
    wcs: galsim.fitswcs.GSFitsWCS,
    hdr: galsim.fits.FitsHeader,
) -> bool:
    """Tell whether a unit pixel scale WCS is GalSim's fallback rather than
    something explicitly recorded in the FITS header.

    Parameters
    ----------
    wcs : galsim.fitswcs.GSFitsWCS
        Potentially default WCS.
    hdr : galsim.fits.FitsHeader
        Header as read in by GalSim.

    Returns
    -------
    is_default : bool
        True if default, False if explicitly set in header.
    """
    # Anything other than PixelScale(1.0), or a header carrying a GS_WCS
    # entry, is taken to be explicit.
    if wcs != galsim.PixelScale(1.0) or hdr.get("GS_WCS") is not None:
        return False

    if hdr.get("CTYPE1", "LINEAR") != "LINEAR":
        # Non-linear projection: if any of the known WCS readers can parse
        # the header, assume the result is explicit.
        for reader in galsim.fitswcs.fits_wcs_types:
            try:
                reader._readHeader(hdr)
            except Exception:
                continue
            return False

    # Otherwise the WCS is explicit only if one of the linear scale keywords
    # is present in the header.
    return "CD1_1" not in hdr and "CDELT1" not in hdr

384 

385 

def inject_galsim_objects_into_exposure(
    exposure: ExposureF,
    objects: Generator[tuple[SpherePoint, Point2D, int, galsim.gsobject.GSObject], None, None],
    mask_plane_name: str = "INJECTED",
    calib_flux_radius: float = 12.0,
    draw_size_max: int = 1000,
    logger: Any | None = None,
) -> tuple[list[int], list[galsim.BoundsI], list[bool], list[bool]]:
    """Inject sources into given exposure using GalSim.

    The exposure is modified in place: each object is convolved with the
    PSF evaluated at its position and added to the image, and two mask
    planes are set, ``mask_plane_name`` over the drawn region and
    ``mask_plane_name + "_CORE"`` over a small region centered on the
    object. Both mask planes are added to the exposure even if no object
    ends up being injected.

    Parameters
    ----------
    exposure : `lsst.afw.image.ExposureF`
        The exposure to inject synthetic sources into.
    objects : `Generator` [`tuple`, None, None]
        An iterator of tuples that contains (or generates) locations and object
        surface brightness profiles to inject. The tuples should contain the
        following elements: `lsst.geom.SpherePoint`, `lsst.geom.Point2D`,
        `int`, `galsim.gsobject.GSObject`.
    mask_plane_name : `str`
        Name of the mask plane to use for the injected sources.
    calib_flux_radius : `float`
        Radius in pixels to use for the flux calibration. This is used to
        produce the correct instrumental fluxes within the radius. The value
        should match that of the field defined in slot_CalibFlux_instFlux.
    draw_size_max : `int`
        Maximum allowed size of the drawn object. If the object is larger than
        this, the draw size will be clipped to this size. Clipping is only
        applied when this value is greater than zero.
    logger : `lsst.utils.logging.LsstLogAdapter`, optional
        Logger to use for logging messages.

    Returns
    -------
    draw_sizes : `list` [`int`]
        Draw sizes of the injected sources.
    common_bounds : `list` [`galsim.BoundsI`]
        Common bounds of the drawn objects.
    fft_size_errors : `list` [`bool`]
        Boolean flags indicating whether a GalSimFFTSizeError was raised.
    psf_compute_errors : `list` [`bool`]
        Boolean flags indicating whether a PSF computation error was raised.

    Notes
    -----
    All four returned lists have exactly one entry per item yielded by
    ``objects``, in order. An object that is skipped before its draw size
    and bounds are computed keeps the placeholder values ``0`` and an empty
    ``galsim.BoundsI()``.
    """
    exposure.mask.addMaskPlane(mask_plane_name)
    mask_plane_core_name = mask_plane_name + "_CORE"
    exposure.mask.addMaskPlane(mask_plane_core_name)
    if logger:
        logger.info(
            "Adding %s and %s mask planes to the exposure.",
            mask_plane_name,
            mask_plane_core_name,
        )
    psf = exposure.getPsf()
    wcs = exposure.getWcs()
    bbox = exposure.getBBox()
    full_bounds = galsim.BoundsI(bbox.minX, bbox.maxX, bbox.minY, bbox.maxY)
    # Wrap the exposure's image array in a GalSim image so drawing below adds
    # flux to the exposure itself.
    # NOTE(review): relies on galsim.Image not copying the array - confirm.
    galsim_image = galsim.Image(exposure.image.array, bounds=full_bounds)
    # Reference pixel scale at the center of the exposure, used for the
    # per-object sanity check below.
    pixel_scale = wcs.getPixelScale(bbox.getCenter()).asArcseconds()

    draw_sizes: list[int] = []
    common_bounds: list[galsim.BoundsI] = []
    fft_size_errors: list[bool] = []
    psf_compute_errors: list[bool] = []
    for i, (sky_coords, pixel_coords, draw_size, object) in enumerate(objects):
        # Instantiate default returns in case of early exit from this loop.
        draw_sizes.append(0)
        common_bounds.append(galsim.BoundsI())
        fft_size_errors.append(False)
        psf_compute_errors.append(False)

        # Get spatial coordinates and WCS.
        posd = galsim.PositionD(pixel_coords.x, pixel_coords.y)
        # Floor the pixel coordinates to get the integer pixel position.
        posi = galsim.PositionI(pixel_coords.x // 1, pixel_coords.y // 1)
        if logger:
            logger.debug(f"Injecting synthetic source at {pixel_coords}.")
        # Local linear approximation of the exposure WCS at the object.
        mat = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
        galsim_wcs = galsim.JacobianWCS(mat[0, 0], mat[0, 1], mat[1, 0], mat[1, 1])

        # This check is here because sometimes the WCS is multivalued and
        # objects that should not be included were being included.
        # Objects whose local pixel scale differs from the central scale by
        # more than a factor of two are skipped, without setting any flag.
        galsim_pixel_scale = np.sqrt(galsim_wcs.pixelArea())
        if galsim_pixel_scale < pixel_scale / 2 or galsim_pixel_scale > pixel_scale * 2:
            continue

        # Compute the PSF at the object location.
        try:
            psf_array = psf.computeKernelImage(pixel_coords).array
        except InvalidParameterError:
            # Try mapping to nearest point contained in bbox.
            contained_point = Point2D(
                np.clip(pixel_coords.x, bbox.minX, bbox.maxX), np.clip(pixel_coords.y, bbox.minY, bbox.maxY)
            )
            if pixel_coords == contained_point:  # no difference, so skip immediately
                psf_compute_errors[i] = True
                if logger:
                    logger.debug("Cannot compute PSF for object at %s; flagging and skipping.", sky_coords)
                continue
            # Otherwise, try again with new point.
            try:
                psf_array = psf.computeKernelImage(contained_point).array
            except InvalidParameterError:
                psf_compute_errors[i] = True
                if logger:
                    logger.debug("Cannot compute PSF for object at %s; flagging and skipping.", sky_coords)
                continue

        # Compute the aperture corrected PSF interpolated image.
        # The aperture flux is evaluated at the PSF's average position, not
        # at the object position.
        aperture_correction = psf.computeApertureFlux(calib_flux_radius, psf.getAveragePosition())
        # NOTE(review): in-place division assumes computeKernelImage returned
        # an array this function owns rather than a shared cache - confirm.
        psf_array /= aperture_correction
        galsim_psf = galsim.InterpolatedImage(galsim.Image(psf_array), wcs=galsim_wcs)

        # Convolve the object with the PSF and generate draw size.
        conv = galsim.Convolve(object, galsim_psf)
        # A draw size of 0 means "not specified"; let GalSim pick one.
        if draw_size == 0:
            draw_size = conv.getGoodImageSize(galsim_wcs.minLinearScale())  # type: ignore
        injection_draw_size = int(draw_size)
        injection_core_size = 3
        if draw_size_max > 0 and injection_draw_size > draw_size_max:
            if logger:
                logger.warning(
                    "Clipping draw size for object at %s from %d to %d pixels.",
                    sky_coords,
                    injection_draw_size,
                    draw_size_max,
                )
            injection_draw_size = draw_size_max
        draw_sizes[i] = injection_draw_size
        # The core mask can never be larger than the drawn region.
        if injection_core_size > injection_draw_size:
            if logger:
                logger.debug(
                    "Clipping core size for object at %s from %d to %d pixels.",
                    sky_coords,
                    injection_core_size,
                    injection_draw_size,
                )
            injection_core_size = injection_draw_size
        # Region to draw into: a box around the object's integer position,
        # intersected with the exposure bounds.
        sub_bounds = galsim.BoundsI(posi).withBorder(injection_draw_size // 2)
        object_common_bounds = full_bounds & sub_bounds
        common_bounds[i] = object_common_bounds  # type: ignore

        # Inject the source if there is any overlap.
        if object_common_bounds.area() > 0:
            common_image = galsim_image[object_common_bounds]
            # Sub-pixel offset of the object from the center of the region
            # being drawn into.
            offset = posd - object_common_bounds.true_center
            # Note, for calexp injection, pixel is already part of the PSF and
            # for coadd injection, it's incorrect to include the output pixel.
            # So for both cases, we draw using method='no_pixel'.
            try:
                conv.drawImage(
                    common_image, add_to_image=True, offset=offset, wcs=galsim_wcs, method="no_pixel"
                )
            except GalSimFFTSizeError as err:
                fft_size_errors[i] = True
                if logger:
                    logger.debug(
                        "GalSimFFTSizeError raised for object at %s; flagging and skipping.\n%s",
                        sky_coords,
                        err,
                    )
                # No mask bits are set for an object that failed to draw.
                continue
            common_box = Box2I(
                Point2I(object_common_bounds.xmin, object_common_bounds.ymin),
                Point2I(object_common_bounds.xmax, object_common_bounds.ymax),
            )
            bitvalue = exposure.mask.getPlaneBitMask(mask_plane_name)
            exposure[common_box].mask.array |= bitvalue
            # Add a 3 x 3 pixel mask centered on the object. The mask must be
            # large enough to always identify the core/peak of the injected
            # source, but small enough that it rarely overlaps real sources.
            sub_bounds_core = galsim.BoundsI(posi).withBorder(injection_core_size // 2)
            object_common_bounds_core = full_bounds & sub_bounds_core
            if object_common_bounds_core.area() > 0:
                common_box_core = Box2I(
                    Point2I(object_common_bounds_core.xmin, object_common_bounds_core.ymin),
                    Point2I(object_common_bounds_core.xmax, object_common_bounds_core.ymax),
                )
                bitvalue_core = exposure.mask.getPlaneBitMask(mask_plane_core_name)
                exposure[common_box_core].mask.array |= bitvalue_core
        else:
            # The message below says "flagging", but no error flag is set on
            # this path; the only record is the empty-area entry already
            # stored in common_bounds.
            if logger:
                logger.debug("No area overlap for object at %s; flagging and skipping.", sky_coords)

    return draw_sizes, common_bounds, fft_size_errors, psf_compute_errors