Coverage for python/lsst/cell_coadds/_stitched_psf.py: 47%
87 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-26 03:49 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-26 03:49 -0700
1# This file is part of cell_coadds.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24import pickle
25from functools import partial
26from typing import TYPE_CHECKING, Any, ClassVar
28import lsst.geom as geom
29import numpy as np
30from lsst.afw.image import ImageD
31from lsst.afw.typehandling import StorableHelperFactory
32from lsst.meas.algorithms import ImagePsf
34from ._grid_container import GridContainer
35from ._uniform_grid import UniformGrid
37if TYPE_CHECKING: 37 ↛ 38line 37 didn't jump to line 38, because the condition on line 37 was never true
38 from lsst.afw.image import Color
# Public API of this module: only the StitchedPsf class is exported.
__all__ = ("StitchedPsf",)
class StitchedPsf(ImagePsf):
    """A piecewise PSF implementation backed by a 2-d grid of images.

    The kernel image at a position is the stored image for the grid cell
    that ``grid.index`` returns for that position. No interpolation between
    cells is performed.

    Parameters
    ----------
    images : `GridContainer` [`lsst.afw.image.ImageD`]
        Kernel images, one per grid cell.
    grid : `UniformGrid`
        The grid used to map positions to cell indices.

    Raises
    ------
    ValueError
        Raised if ``images`` is empty or spans more cells than ``grid``.
    """

    # We need to ensure a C++ StorableHelperFactory is constructed and
    # available before any unpersists of this class. Placing this "private"
    # class attribute here accomplishes that. The value is a factory
    # *instance*, so it is annotated as such (not as ``type[...]``).
    _factory: ClassVar[StorableHelperFactory] = StorableHelperFactory("lsst.cell_coadds", "StitchedPsf")

    def __init__(self, images: GridContainer[ImageD], grid: UniformGrid) -> None:
        # Validate first so an incompatible images/grid pair raises before
        # the base class is initialized.
        self._validate_args(images, grid)

        super().__init__()
        self._images = images
        self._grid = grid

    @staticmethod
    def _validate_args(images: GridContainer[ImageD], grid: UniformGrid) -> None:
        """Validate the images and grid.

        Parameters
        ----------
        images : `GridContainer`
            The images to validate.
        grid : `UniformGrid`
            The grid to validate.

        Raises
        ------
        ValueError
            Raised if ``images`` is empty, or if the range of image indices
            along either axis is larger than the grid shape.
        """
        # Materialize once instead of walking ``images.indices()`` four times.
        indices = list(images.indices())
        if not indices:
            raise ValueError("No images provided.")

        xs = [index.x for index in indices]
        ys = [index.y for index in indices]
        if (max(xs) - min(xs) + 1) > grid.shape.x or (max(ys) - min(ys) + 1) > grid.shape.y:
            raise ValueError("Images do not fit on grid.")

    @property
    def images(self) -> GridContainer[ImageD]:
        """The images that make up this PSF."""
        return self._images

    @property
    def grid(self) -> UniformGrid:
        """The grid on which the images are placed."""
        return self._grid

    # The _do* methods make use of the ImagePsf trampoline.
    def _doComputeBBox(self, position: geom.Point2D | geom.Point2I, color: Color | None = None) -> geom.Box2I:
        """Return the bounding box of the kernel image for ``position``.

        ``position`` is converted with `lsst.geom.Point2I` before the grid
        lookup. ``color`` is accepted for interface compatibility and is
        not used.
        """
        return self._images[self._grid.index(geom.Point2I(position))].getBBox()

    def _doComputeKernelImage(
        self, position: geom.Point2D | geom.Point2I, color: Color | None = None
    ) -> ImageD:
        """Return the stored kernel image for the cell containing ``position``.

        The stored image object itself is returned, not a copy. ``color`` is
        accepted for interface compatibility and is not used.
        """
        return self._images[self._grid.index(geom.Point2I(position))]

    def clone(self) -> StitchedPsf:
        """Return a new `StitchedPsf` built from the same images and grid.

        Notes
        -----
        The image container and grid are shared with ``self``, not copied.
        """
        # NOTE(review): this is a shallow copy; it is only safe as long as the
        # stored images are never modified in place. Confirm that is intended.
        return StitchedPsf(self.images, self.grid)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> StitchedPsf:
        """Return a new `StitchedPsf` built from the same images and grid.

        Notes
        -----
        Despite the method name, the image container and grid are shared
        with ``self``, not copied, and ``memo`` is not used.
        """
        # NOTE(review): shallow despite being __deepcopy__; see `clone`.
        return StitchedPsf(self.images, self.grid)

    def __eq__(self, other: object) -> bool:
        """Return whether ``other`` has an equal grid and identical pixel arrays.

        Only the pixel arrays are compared; image bounding-box origins are
        not. Indices present in ``self.images`` are looked up in
        ``other.images``.
        """
        if not isinstance(other, StitchedPsf):
            return False

        if self.grid != other.grid:
            return False

        # np.array_equal returns False on a shape mismatch, covering the
        # explicit shape check as well as element-wise equality.
        return all(
            np.array_equal(self.images[index].array, other.images[index].array)
            for index in self.images.indices()
        )

    @staticmethod
    def _callback(image: ImageD, bbox: geom.Box2I) -> ImageD:
        """Return the part of ``image`` covered by ``bbox``, zero-padding as needed.

        Parameters
        ----------
        image : `lsst.afw.image.ImageD`
            The source kernel image.
        bbox : `lsst.geom.Box2I`
            The desired bounding box.

        Returns
        -------
        subimage : `lsst.afw.image.ImageD`
            A subimage view with bounding box ``bbox``.
        """
        if image.getBBox().contains(bbox):
            return image[bbox]
        # Make a new image big enough to fit the current bbox and the new
        # bbox, copy the current image into it, then subset that for the
        # returned PSF.
        bigger_image = ImageD(bbox=bbox.expandedTo(image.getBBox()), initialValue=0.0)
        bigger_image[image.getBBox()] = image
        return bigger_image[bbox]

    def resized(self, width: int, height: int) -> StitchedPsf:
        """Return a new PSF whose kernel images all have the given size.

        Each image is cropped, or zero-padded, to a ``width`` x ``height``
        box centered on the origin.

        Parameters
        ----------
        width : `int`
            New image width; must be a positive odd integer.
        height : `int`
            New image height; must be a positive odd integer.

        Returns
        -------
        psf : `StitchedPsf`
            The resized PSF, on the same grid.

        Raises
        ------
        ValueError
            Raised if ``width`` or ``height`` is not a positive odd integer.
        """
        if not (width % 2 == 1 and width > 0):
            raise ValueError(f"resized width must be a positive odd integer; got {width}.")
        if not (height % 2 == 1 and height > 0):
            raise ValueError(f"resized height must be a positive odd integer; got {height}.")

        bbox = geom.Box2I(geom.Point2I(-(width // 2), -(height // 2)), geom.Extent2I(width, height))
        gc = self._images.rebuild_transformed(transform=partial(self._callback, bbox=bbox))
        return StitchedPsf(gc, self.grid)

    @staticmethod
    def isPersistable() -> bool:
        """Return `True`; this PSF supports persistence via `_write`/`_read`."""
        return True

    @staticmethod
    def _getPersistenceName() -> str:
        # Must match the name registered with ``_factory`` above.
        return "StitchedPsf"

    @staticmethod
    def _getPythonModule() -> str:
        # The module in which this class is defined.
        return __name__

    # The get/set state methods are needed to support pickle.
    def __getstate__(self) -> dict:
        return {"images": self.images, "grid": self.grid}

    def __setstate__(self, state: dict) -> None:
        # Call __init__ explicitly so validation runs and the base class is
        # initialized on the unpickled object.
        StitchedPsf.__init__(self, state["images"], state["grid"])

    def _write(self) -> bytes:
        """Serialize the images and grid with pickle."""
        return pickle.dumps((self._images, self._grid))

    @staticmethod
    def _read(pkl: bytes) -> StitchedPsf:
        """Reconstruct a PSF from bytes produced by `_write`."""
        # NOTE(review): pickle.loads can execute arbitrary code; only read
        # persisted data that comes from a trusted source.
        return StitchedPsf(*pickle.loads(pkl))

    def writeFits(self, name: str) -> None:
        """Persist the PSF as a FITS file (not supported; always raises)."""
        raise NotImplementedError("FITS persistence not implemented for StitchedPsf.")

    def readFits(self, name: str) -> None:
        """Read the PSF from a FITS file (not supported; always raises)."""
        raise NotImplementedError("FITS persistence not implemented for StitchedPsf.")