Coverage for python / lsst / cell_coadds / _stitched_psf.py: 37%

104 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-23 08:47 +0000

1# This file is part of cell_coadds. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import pickle 

25from functools import partial 

26from typing import TYPE_CHECKING, Any, ClassVar 

27 

28import numpy as np 

29 

30import lsst.geom as geom 

31from lsst.afw.detection import InvalidPsfError 

32from lsst.afw.image import ImageD 

33from lsst.afw.typehandling import StorableHelperFactory 

34from lsst.meas.algorithms import ImagePsf 

35 

36from ._grid_container import GridContainer 

37from ._uniform_grid import UniformGrid 

38 

if TYPE_CHECKING:
    # Only referenced from (lazy) type annotations; not imported at runtime.
    from lsst.afw.image import Color

# The public API exported by this module.
__all__ = ("StitchedPsf",)

43 

44 

class StitchedPsf(ImagePsf):
    """A piecewise PSF implementation backed by a 2-d grid of images.

    The kernel image for a position is the image stored in ``images`` at the
    index of the ``grid`` cell that contains that position.

    Parameters
    ----------
    images : `GridContainer` [`lsst.afw.image.ImageD`]
        Kernel images, keyed by grid index.
    grid : `UniformGrid`
        Grid used to map positions to indices into ``images``.

    Raises
    ------
    ValueError
        Raised if ``images`` is empty or its indices span more cells than
        ``grid`` has along either axis.
    """

    # We need to ensure a C++ StorableHelperFactory is constructed and
    # available before any unpersists of this class. Placing this "private"
    # class attribute here accomplishes that.
    _factory: ClassVar[StorableHelperFactory] = StorableHelperFactory("lsst.cell_coadds", "StitchedPsf")

    def __init__(self, images: GridContainer[ImageD], grid: UniformGrid) -> None:
        self._validate_args(images, grid)

        super().__init__()
        self._images = images
        self._grid = grid
        # Computed lazily and cached by getAveragePosition().
        self._averagePosition: geom.Point2D | None = None

    @staticmethod
    def _validate_args(images: GridContainer[ImageD], grid: UniformGrid) -> None:
        """Validate the images and grid.

        Parameters
        ----------
        images : `GridContainer`
            The images to validate.
        grid : `UniformGrid`
            The grid to validate.

        Raises
        ------
        ValueError
            Raised if ``images`` is empty, or if the images and grid are
            incompatible (the occupied indices span more cells than the grid
            has along either axis).
        """
        # Materialize once rather than walking the container four times.
        indices = list(images.indices())
        if not indices:
            # Without this check min()/max() would raise an opaque
            # "empty sequence" ValueError.
            raise ValueError("No images provided; cannot construct a StitchedPsf.")

        xs = [index.x for index in indices]
        ys = [index.y for index in indices]

        # Only the extent of the occupied indices is compared with the grid
        # shape; their absolute positions are not checked here.
        if (max(xs) - min(xs) + 1) > grid.shape.x or (max(ys) - min(ys) + 1) > grid.shape.y:
            raise ValueError("Images do not fit on grid.")

    @property
    def images(self) -> GridContainer[ImageD]:
        """The images that make up this PSF."""
        return self._images

    @property
    def grid(self) -> UniformGrid:
        """The grid on which the images are placed."""
        return self._grid

    def getAveragePosition(self) -> geom.Point2D:
        """Get a position where PSF can be evaluated on a patch.

        This defaults to the center of the grid bounding box, unless there are
        no inputs there. In that case, it switches to an arbitrary cell that
        has inputs (the first index yielded by ``images.indices()``) and
        returns the center position of that cell.

        Returns
        -------
        position : `lsst.geom.Point2D`
            The chosen position. It is computed once and cached.
        """
        if self._averagePosition is None:
            center = self._grid.bbox.getCenter()
            if self._grid.index(geom.Point2I(center)) not in self._images:
                # _validate_args guarantees at least one populated index.
                fallback_index = next(iter(self._images.indices()))
                center = self._grid.bbox_of(fallback_index).getCenter()

            self._averagePosition = center

        return self._averagePosition

    def _image_at(self, position: geom.Point2D | geom.Point2I) -> ImageD:
        """Return the stored image for the grid cell containing ``position``.

        The stored image object itself is returned, not a copy.

        Raises
        ------
        lsst.afw.detection.InvalidPsfError
            Raised if the grid or container lookup fails with `KeyError` or
            `ValueError` (e.g. the cell has no image).
        """
        try:
            return self._images[self._grid.index(geom.Point2I(position))]
        except (KeyError, ValueError):
            raise InvalidPsfError("No inputs exists at position.") from None

    # The _do* methods make use of the ImagePsf trampoline.
    def _doComputeBBox(self, position: geom.Point2D | geom.Point2I, color: Color | None = None) -> geom.Box2I:
        """Return the bounding box of the kernel image at ``position``.

        ``color`` is accepted for interface compatibility and ignored.
        """
        return self._image_at(position).getBBox()

    def _doComputeKernelImage(
        self, position: geom.Point2D | geom.Point2I, color: Color | None = None
    ) -> ImageD:
        """Return the kernel image stored for the cell containing ``position``.

        ``color`` is accepted for interface compatibility and ignored.
        """
        return self._image_at(position)

    def clone(self) -> StitchedPsf:
        """Return a new `StitchedPsf` built from this one's images and grid.

        The underlying ``images`` container and ``grid`` are shared with this
        object, not copied.
        NOTE(review): this relies on the images never being modified in place.
        """
        return StitchedPsf(self.images, self.grid)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> StitchedPsf:
        """Return a new `StitchedPsf` sharing this one's images and grid.

        Despite the name, the images are not deep-copied (see `clone`);
        ``memo`` is unused.
        """
        return StitchedPsf(self.images, self.grid)

    def __eq__(self, other: object) -> bool:
        """Return whether ``other`` is a `StitchedPsf` with an equal grid and
        element-wise equal images at exactly the same indices.
        """
        if not isinstance(other, StitchedPsf):
            return False

        if not (self.grid == other.grid):
            return False

        # Compare the index sets first: this keeps equality symmetric and
        # avoids a lookup error when ``other`` lacks one of our indices.
        indices = set(self.images.indices())
        if indices != set(other.images.indices()):
            return False

        # np.array_equal returns False for mismatched shapes as well.
        return all(np.array_equal(self.images[index].array, other.images[index].array) for index in indices)

    @staticmethod
    def _callback(image: ImageD, bbox: geom.Box2I) -> ImageD:
        """Return ``image`` restricted to ``bbox``, zero-padding if needed.

        If ``image`` already covers ``bbox`` a subimage of it is returned;
        otherwise a zero-filled image covering both boxes is created, the
        original pixels are copied in, and its ``bbox`` subimage is returned.
        """
        if image.getBBox().contains(bbox):
            return image[bbox]
        else:
            # Make a new image big enough to fit current bbox and new bbox,
            # copy current image into it, then subset that for the returned
            # PSF.
            bigger_image = ImageD(bbox=bbox.expandedTo(image.getBBox()), initialValue=0.0)
            bigger_image[image.getBBox()] = image
            return bigger_image[bbox]

    def resized(self, width: int, height: int) -> StitchedPsf:
        """Return a new PSF whose kernel images are ``width`` x ``height``.

        Every image is cropped or zero-padded to a box spanning
        ``[-(width // 2), width // 2]`` by ``[-(height // 2), height // 2]``.

        Parameters
        ----------
        width : `int`
            New kernel width; must be a positive odd integer.
        height : `int`
            New kernel height; must be a positive odd integer.

        Raises
        ------
        ValueError
            Raised if ``width`` or ``height`` is not a positive odd integer.
        """
        if not (width % 2 == 1 and width > 0):
            raise ValueError(f"resized width must be a positive odd integer; got {width}.")
        if not (height % 2 == 1 and height > 0):
            raise ValueError(f"resized height must be a positive odd integer; got {height}.")

        bbox = geom.Box2I(geom.Point2I(-(width // 2), -(height // 2)), geom.Extent2I(width, height))
        gc = self._images.rebuild_transformed(transform=partial(self._callback, bbox=bbox))
        return StitchedPsf(gc, self.grid)

    @staticmethod
    def isPersistable() -> bool:
        """Return `True`; instances can be persisted via `_write`/`_read`."""
        return True

    @staticmethod
    def _getPersistenceName() -> str:
        """Return the name under which this class is persisted."""
        return "StitchedPsf"

    @staticmethod
    def _getPythonModule() -> str:
        """Return the module that defines this class (used on unpersist)."""
        return __name__

    # The get/set state methods are needed to support pickle.
    def __getstate__(self) -> dict[str, Any]:
        return {"images": self.images, "grid": self.grid}

    def __setstate__(self, state: dict[str, Any]) -> None:
        # Re-run full construction (including validation and base-class init).
        StitchedPsf.__init__(self, state["images"], state["grid"])

    def _write(self) -> bytes:
        """Serialize the images and grid to bytes with `pickle`."""
        return pickle.dumps((self._images, self._grid))

    @staticmethod
    def _read(pkl: bytes) -> StitchedPsf:
        """Reconstruct a `StitchedPsf` from bytes produced by `_write`."""
        # SECURITY NOTE(review): pickle.loads can execute arbitrary code;
        # only read persisted data from trusted sources.
        return StitchedPsf(*pickle.loads(pkl))

    def writeFits(self, name: str) -> None:
        """Persist the PSF as a FITS file (not supported).

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("FITS persistence not implemented for StitchedPsf.")

    def readFits(self, name: str) -> None:
        """Read the PSF from a FITS file (not supported).

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("FITS persistence not implemented for StitchedPsf.")