Coverage for tests / test_getTemplate.py: 20%
162 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:17 +0000
1# This file is part of ip_diffim.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import collections
23import itertools
24import unittest
26import numpy as np
28import lsst.afw.geom
29import lsst.afw.image
30import lsst.afw.math
31from lsst.daf.butler import DataCoordinate, DimensionUniverse
32import lsst.geom
33import lsst.ip.diffim
34import lsst.meas.algorithms
35import lsst.meas.base.tests
36import lsst.pipe.base as pipeBase
37import lsst.skymap
38import lsst.utils.tests
40from utils import generate_data_id
# Flip this to True, run `setup display_ds9`, and start ds9 (or point afw at
# another display backend) to draw the tract/patch layouts over the image.
debug = False
if debug:
    import lsst.afw.display

    display = lsst.afw.display.Display()
    display.frame = 1


def _showTemplate(box, template):
    """Visualize a template made by this test.

    Marks each corner of ``box`` with an orange "+" on the display's current
    frame, then shows the template on frame 2 and its variance plane on
    frame 3.

    Only usable when ``debug`` is True: it relies on the module-level
    ``display`` object, which is only created in that case.
    """
    for corner in box.getCorners():
        display.dot("+", corner.x, corner.y, ctype="orange", size=40)
    panels = ((2, template, "warped template"),
              (3, template.variance, "warped variance"))
    for frame, data, title in panels:
        display.frame = frame
        display.image(data, title)
class GetTemplateTaskTestCase(lsst.utils.tests.TestCase):
    """Test that GetTemplateTask works on both one tract and multiple tract
    input coadd exposures.

    Makes a synthetic exposure large enough to fit four small tracts with 2x2
    (300x300 pixel) patches each, extracts pixels for those patches by warping,
    and tests GetTemplateTask's output against boxes that overlap various
    combinations of one or multiple tracts.
    """
    def setUp(self):
        self.scale = 0.2  # arcsec/pixel
        self.skymap = self._makeSkymap()
        # Both dicts are keyed on tract id. ``_makePatches`` appends to the
        # two per-tract lists in lockstep, so element i of each refers to the
        # same patch; tests that remove a patch must remove it from both.
        self.patches = collections.defaultdict(list)
        self.dataIds = collections.defaultdict(list)
        self.exposure = self._makeExposure()

        if debug:
            display.image(self.exposure, "base exposure")

        for tract_id in range(4):
            tract = self.skymap.generateTract(tract_id)
            self._makePatches(tract)

    def _makeSkymap(self):
        """Make a Skymap with 4 tracts with 4 patches each.
        """
        tractScale = 0.02  # degrees
        # On-sky (ra, dec) coordinates of the tract centers.
        coords = [(0, 0),
                  (0, tractScale),
                  (tractScale, 0),
                  (tractScale, tractScale),
                  ]
        config = lsst.skymap.DiscreteSkyMap.ConfigClass()
        config.raList = [ra for ra, _ in coords]
        config.decList = [dec for _, dec in coords]
        # Half the tract center step size, to keep the tract overlap small.
        config.radiusList = [tractScale/2] * len(coords)
        config.projection = "TAN"
        config.pixelScale = self.scale
        config.tractOverlap = 0.0005
        config.tractBuilder = "legacy"
        config.tractBuilder["legacy"].patchInnerDimensions = (300, 300)
        config.tractBuilder["legacy"].patchBorder = 10
        return lsst.skymap.DiscreteSkyMap(config=config)

    def _makeExposure(self):
        """Create a large image to break up into tracts and patches.

        The image will have a source every 100 pixels in x and y, and a WCS
        that results in the tracts all fitting in the image, with tract=0
        in the lower left, tract=1 to the right, tract=2 above, and tract=3
        to the upper right.
        """
        box = lsst.geom.Box2I(lsst.geom.Point2I(-200, -200), lsst.geom.Point2I(800, 800))
        # This WCS was constructed so that tract 0 mostly fills the lower left
        # quadrant of the image, and the other tracts fill the rest; slight
        # extra rotation as a check on the final warp layout, scaled by 5%
        # from the patch pixel scale.
        cd_matrix = lsst.afw.geom.makeCdMatrix(1.05*self.scale*lsst.geom.arcseconds, 93*lsst.geom.degrees)
        wcs = lsst.afw.geom.makeSkyWcs(lsst.geom.Point2D(120, 150),
                                       lsst.geom.SpherePoint(0, 0, lsst.geom.radians),
                                       cd_matrix)
        dataset = lsst.meas.base.tests.TestDataset(box, wcs=wcs)
        for x, y in itertools.product(np.arange(0, 500, 100), np.arange(0, 500, 100)):
            dataset.addSource(1e5, lsst.geom.Point2D(x, y))
        exposure, _ = dataset.realize(2, dataset.makeMinimalSchema())
        exposure.setFilter(lsst.afw.image.FilterLabel("a", "a_test"))
        return exposure

    def _makePatches(self, tract):
        """Populate the patches and dataId dicts, keyed on tract id, with the
        warps of the main exposure and minimal dataIds, respectively.

        The handle and the dataId for a patch are appended in the same loop
        iteration, so the two per-tract lists stay index-aligned.
        """
        if debug:
            color = ['red', 'green', 'cyan', 'yellow'][tract.tract_id]
            point = self.exposure.wcs.skyToPixel(tract.ctr_coord)
            # Show the tract center, colored by tract id.
            display.dot("x", point.x, point.y, ctype=color, size=30)

        # Use 5th order to minimize artifacts on the templates.
        config = lsst.afw.math.Warper.ConfigClass()
        config.warpingKernelName = "lanczos5"
        warper = lsst.afw.math.Warper.fromConfig(config)
        # Loop-invariant: build the dimension universe once for all patches.
        universe = DimensionUniverse()
        for patchId in range(tract.num_patches.x*tract.num_patches.y):
            patch = tract.getPatchInfo(patchId)
            box = patch.getOuterBBox()

            if debug:
                # Show the patch corners as patch ids, colored by tract id.
                points = self.exposure.wcs.skyToPixel(patch.wcs.pixelToSky([lsst.geom.Point2D(x)
                                                                            for x in box.getCorners()]))
                for p in points:
                    display.dot(patchId, p.x, p.y, ctype=color)

            # This is mostly taken from drp_tasks makePsfMatchedWarp, but
            # ip_diffim cannot depend on drp_tasks.
            xyTransform = lsst.afw.geom.makeWcsPairTransform(self.exposure.wcs, patch.wcs)
            warpedPsf = lsst.meas.algorithms.WarpedPsf(self.exposure.psf, xyTransform)
            warped = warper.warpExposure(patch.wcs, self.exposure, destBBox=box)
            warped.setPsf(warpedPsf)
            dataRef = pipeBase.InMemoryDatasetHandle(
                warped,
                storageClass="ExposureF",
                copy=True,
                dataId=generate_data_id(
                    tract=tract,
                    patch=patch,
                )
            )
            self.patches[tract.tract_id].append(dataRef)
            dataCoordinate = DataCoordinate.standardize({"tract": tract.tract_id,
                                                         "patch": patchId,
                                                         "band": "a",
                                                         "skymap": "skymap"},
                                                        universe=universe)
            self.dataIds[tract.tract_id].append(dataCoordinate)

    def _checkMetadata(self, template, config, box, wcs, nPsfs):
        """Check that the various metadata components were set correctly.

        Parameters
        ----------
        template : `lsst.afw.image.Exposure`
            Template returned by GetTemplateTask.run.
        config
            Config of the task that produced ``template``; only
            ``templateBorderSize`` is read.
        box : `lsst.geom.Box2I`
            Bounding box passed to the task, before the border is added.
        wcs : `lsst.afw.geom.SkyWcs`
            WCS the template is expected to carry.
        nPsfs : `int`
            Expected number of components in the template's PSF.
        """
        expectedBox = lsst.geom.Box2I(box)
        expectedBox.grow(config.templateBorderSize)
        self.assertEqual(template.getBBox(), expectedBox)
        # WCS should match the requested one, not any of the coadd tracts.
        for patchRefs in self.patches.values():
            self.assertNotEqual(template.wcs, patchRefs[0].get().wcs)
        self.assertEqual(template.wcs, wcs)
        self.assertEqual(template.photoCalib, self.exposure.photoCalib)
        self.assertEqual(template.getXY0(), expectedBox.getMin())
        self.assertEqual(template.filter.bandLabel, "a")
        self.assertEqual(template.filter.physicalLabel, "a_test")
        self.assertEqual(template.psf.getComponentCount(), nPsfs)

    def _checkPixels(self, template, config, box):
        """Check that the pixel values in the template are close to the
        original image.

        ``box`` is the bounding box passed to the task; it is grown by
        ``config.templateBorderSize`` to get the template's footprint.
        """
        expectedBox = lsst.geom.Box2I(box)
        expectedBox.grow(config.templateBorderSize)

        if debug:
            _showTemplate(expectedBox, template)

        # Check that we fully filled the template from the patches: all
        # pixels should have real values!
        self.assertTrue(np.all(np.isfinite(template.image.array)))
        # Because of the scale changes, there will be some ringing in the
        # difference between the template and the original image; pick
        # tolerances large enough to account for that.
        self.assertImagesAlmostEqual(template.image, self.exposure[expectedBox].image,
                                     rtol=.1, atol=4)
        # Variance plane ==2 in the original image, but the warped images will
        # have some structure due to the warping.
        self.assertImagesAlmostEqual(template.variance, self.exposure[expectedBox].variance,
                                     rtol=0.55, msg="variance planes differ")
        # Not checking the mask, as warping changes the sizes of the masks.

    def testRunOneTractInput(self):
        """Test a bounding box that fully fits inside one tract, with only
        that tract passed as input. This checks that the code handles a single
        tract input correctly.
        """
        box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180))
        task = lsst.ip.diffim.GetTemplateTask()
        # Restrict to tract 0, since the box fits in just that tract.
        # Task modifies the input bbox, so pass a copy.
        result = task.run(coaddExposureHandles={0: self.patches[0]},
                          bbox=lsst.geom.Box2I(box),
                          wcs=self.exposure.wcs,
                          dataIds={0: self.dataIds[0]},
                          physical_filter="a_test")

        # All 4 patches from tract 0 are included in this template.
        self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 4)
        self._checkPixels(result.template, task.config, box)

    def testRunOneTractMultipleInputs(self):
        """Test a bounding box that fully fits inside one tract but where
        multiple tracts were passed in. This checks that patches that are
        mostly NaN after warping are merged correctly in the output.
        """
        box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180))
        task = lsst.ip.diffim.GetTemplateTask()
        # Task modifies the input bbox, so pass a copy.
        result = task.run(coaddExposureHandles=self.patches,
                          bbox=lsst.geom.Box2I(box),
                          wcs=self.exposure.wcs,
                          dataIds=self.dataIds,
                          physical_filter="a_test")

        # Same box as testRunOneTractInput (4 PSF components from tract 0
        # alone), but with every tract as input, patches from neighbouring
        # tracts also overlap the border-grown box: 6 components in total.
        self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 6)
        self._checkPixels(result.template, task.config, box)

    def testRunTwoTracts(self):
        """Test a bounding box that crosses tract boundaries.
        """
        box = lsst.geom.Box2I(lsst.geom.Point2I(200, 200), lsst.geom.Point2I(600, 600))
        task = lsst.ip.diffim.GetTemplateTask()
        # Task modifies the input bbox, so pass a copy.
        result = task.run(coaddExposureHandles=self.patches,
                          bbox=lsst.geom.Box2I(box),
                          wcs=self.exposure.wcs,
                          dataIds=self.dataIds,
                          physical_filter="a_test")

        # Patches from several tracts overlap this box, giving 9 PSF
        # components in the template.
        self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 9)
        self._checkPixels(result.template, task.config, box)

    def testRunNoTemplate(self):
        """A bounding box that doesn't overlap the patches will raise.
        """
        box = lsst.geom.Box2I(lsst.geom.Point2I(1200, 1200), lsst.geom.Point2I(1600, 1600))
        task = lsst.ip.diffim.GetTemplateTask()
        with self.assertRaisesRegex(pipeBase.NoWorkFound, "No patches found"):
            task.run(coaddExposureHandles=self.patches,
                     bbox=lsst.geom.Box2I(box),
                     wcs=self.exposure.wcs,
                     dataIds=self.dataIds,
                     physical_filter="a_test")

    def testMissingPatches(self):
        """Test that a missing patch results in an appropriate mask.

        This fixes the bug reported on DM-44997 (image and variance were NaN
        but the mask was not set to NO_DATA for those pixels).
        """
        # tract=0, patch=1 is the lower-left corner, as displayed in DS9.
        # Remove it from both lists so the handles and dataIds stay aligned;
        # otherwise every later patch in tract 0 gets its neighbour's dataId.
        self.patches[0].pop(1)
        self.dataIds[0].pop(1)
        box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180))
        task = lsst.ip.diffim.GetTemplateTask()
        # Task modifies the input bbox, so pass a copy.
        result = task.run(coaddExposureHandles=self.patches,
                          bbox=lsst.geom.Box2I(box),
                          wcs=self.exposure.wcs,
                          dataIds=self.dataIds,
                          physical_filter="a_test")
        no_data = (result.template.mask.array & result.template.mask.getPlaneBitMask("NO_DATA")) != 0
        self.assertTrue(np.isfinite(result.template.image.array).all())
        self.assertTrue(np.isfinite(result.template.variance.array).all())
        self.assertEqual(no_data.sum(), 20990)

    @lsst.utils.tests.methodParameters(
        box=[
            lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180)),
            lsst.geom.Box2I(lsst.geom.Point2I(200, 200), lsst.geom.Point2I(600, 600)),
        ],
        # PSF component counts for these boxes with every tract as input; they
        # match testRunOneTractMultipleInputs and testRunTwoTracts.
        nPsfs=[6, 9],
    )
    def testNanInputs(self, box=None, nPsfs=None):
        """Test that the template has finite values when some of the input
        pixels have NaN as variance.
        """
        for patchRefs in self.patches.values():
            for patchRef in patchRefs:
                # The handles were built with ``copy=True``, so ``get()``
                # returns a copy; edit the wrapped exposure itself so that the
                # NaNs are actually seen by the task.
                patchCoadd = patchRef.inMemoryDataset
                bbox = lsst.geom.Box2I()
                bbox.include(lsst.geom.Point2I(patchCoadd.getBBox().getCenter()))
                bbox.grow(3)
                patchCoadd.variance[bbox].array[:, :] = np.nan
                # Guard against the NaNs silently landing on a throwaway copy.
                self.assertTrue(np.isnan(patchRef.get().variance.array).any())

        task = lsst.ip.diffim.GetTemplateTask()
        # Task modifies the input bbox, so pass a copy.
        result = task.run(coaddExposureHandles=self.patches,
                          bbox=lsst.geom.Box2I(box),
                          wcs=self.exposure.wcs,
                          dataIds=self.dataIds,
                          physical_filter="a_test")
        if debug:
            _showTemplate(result.template.getBBox(), result.template)
        self._checkMetadata(result.template, task.config, box, self.exposure.wcs, nPsfs)
        # We just check that the pixel values are all finite. We cannot check that pixel values
        # in the template are closer to the original anymore.
        self.assertTrue(np.isfinite(result.template.image.array).all())
def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` convention).

    Initializes `lsst.utils.tests` once before this module's tests run;
    ``module`` is not used.
    """
    lsst.utils.tests.init()


class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Runs the checks inherited from `lsst.utils.tests.MemoryTestCase`;
    adds no tests of its own.
    """


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()