# Coverage for tests/test_deblend.py: 13% (160 statements)
# coverage.py v7.4.1, created at 2024-02-13 11:55 +0000
# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import unittest

import numpy as np
from numpy.testing import assert_almost_equal

import lsst.afw.image as afwImage
import lsst.scarlet.lite as scl
import lsst.utils.tests
from lsst.afw.detection import Footprint
from lsst.afw.geom import SpanSet, Stencil
from lsst.afw.table import SourceCatalog
from lsst.geom import Point2D, Point2I
from lsst.meas.algorithms import SourceDetectionTask
from lsst.meas.extensions.scarlet.io import monochromaticDataToScarlet, updateCatalogFootprints
from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
from lsst.meas.extensions.scarlet.utils import bboxToScarletBox, scarletBoxToBBox

# Test helper that builds a simulated multiband image; lives next to
# this test module.
from utils import initData
class TestDeblend(lsst.utils.tests.TestCase):
    """Tests for ScarletDeblendTask and the round trip of its model data.

    `setUp` builds a small simulated multiband coadd containing a
    three-source blend and one isolated source; the tests run detection
    and deblending on it and then verify the resulting catalog and the
    persisted scarlet models.
    """

    def setUp(self):
        # Set the random seed so that the noise field is unaffected
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
        result = initData(shape, coords, amplitudes)
        targetPsfImage, psfImages, images, channels, seds, morphs, psfs = result
        B, Ny, Nx = shape

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise

        self.bands = "grizy"
        # Use the squared noise as the variance plane; no mask plane.
        _images = afwImage.MultibandMaskedImage.fromArrays(
            self.bands,
            images.astype(np.float32),
            None,
            noise**2
        )
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
        self.coadds = afwImage.MultibandExposure.fromExposures(self.bands, coadds)
        for b, coadd in enumerate(self.coadds):
            coadd.setPsf(psfs[b])

    def _deblend(self, version):
        """Run detection and deblending on the test coadds.

        Parameters
        ----------
        version : `str`
            The deblender implementation to use; stored in the task
            config's ``version`` field.

        Returns
        -------
        catalog : `lsst.afw.table.SourceCatalog`
            The deblended catalog.
        modelData :
            The persisted scarlet model data produced by the task.
        config : `ScarletDeblendTask.ConfigClass`
            The config used to run the deblend task.
        """
        schema = SourceCatalog.Table.makeMinimalSchema()
        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        config.catchFailures = False
        config.version = version

        # Detect sources
        detectionTask = SourceDetectionTask(schema=schema)
        deblendTask = ScarletDeblendTask(schema=schema, config=config)
        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, self.coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)

        # Run the deblender
        catalog, modelData = deblendTask.run(self.coadds, catalog)
        return catalog, modelData, config

    def test_deblend_task(self):
        """Deblend the test image and verify catalog/footprint consistency."""
        catalog, modelData, config = self._deblend("lite")

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution to test all of the different possible
        # options of loading catalog footprints.
        for useFlux in [False, True]:
            for band in self.bands:
                bandIndex = self.bands.index(band)
                coadd = self.coadds[band]

                if useFlux:
                    imageForRedistribution = coadd
                else:
                    imageForRedistribution = None

                updateCatalogFootprints(
                    modelData,
                    catalog,
                    band=band,
                    imageForRedistribution=imageForRedistribution,
                    removeScarletData=False,
                )

                # Check that the number of deblended children is consistent
                parents = catalog[catalog["parent"] == 0]
                self.assertEqual(np.sum(catalog["deblend_nChild"]), len(catalog)-len(parents))

                # Check that the models have not been cleared
                # from the modelData
                self.assertEqual(len(modelData.blends), np.sum(~parents["deblend_skipped"]))

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]
                    # Check that nChild is set correctly
                    self.assertEqual(len(children), parent.get("deblend_nChild"))
                    # Check that parent columns are propagated
                    # to their children
                    for parentCol, childCol in config.columnInheritance.items():
                        np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

                children = catalog[catalog["parent"] != 0]
                for child in children:
                    fp = child.getFootprint()
                    img = fp.extractImage(fill=0.0)
                    # Check that the flux at the center is correct.
                    # Note: this only works in this test image because the
                    # detected peak is in the same location as the
                    # scarlet peak.
                    # If the peak is shifted, the flux value will be correct
                    # but deblend_peak_center is not the correct location.
                    px = child.get("deblend_peak_center_x")
                    py = child.get("deblend_peak_center_y")
                    flux = img[Point2I(px, py)]
                    self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                    # Check that the peak positions match the catalog entry
                    peaks = fp.getPeaks()
                    self.assertEqual(px, peaks[0].getIx())
                    self.assertEqual(py, peaks[0].getIy())

                    # Load the data to check against the HeavyFootprint
                    blendData = modelData.blends[child["parent"]]
                    # We need to set an observation in order to convolve
                    # the model.
                    position = Point2D(*blendData.psf_center[::-1])
                    _psfs = self.coadds[band].getPsf().computeKernelImage(position).array[None, :, :]
                    modelBox = scl.Box(blendData.shape, origin=blendData.origin)
                    observation = scl.Observation.empty(
                        bands=("dummy", ),
                        psfs=_psfs,
                        model_psf=modelData.psf[None, :, :],
                        bbox=modelBox,
                        dtype=np.float32,
                    )
                    blend = monochromaticDataToScarlet(
                        blendData=blendData,
                        bandIndex=bandIndex,
                        observation=observation,
                    )
                    # The stored PSF should be the same as the calculated one
                    assert_almost_equal(blendData.psf[bandIndex:bandIndex+1], _psfs)

                    # Get the scarlet model for the source
                    source = [src for src in blend.sources if src.record_id == child.getId()][0]
                    self.assertEqual(source.center[1], px)
                    self.assertEqual(source.center[0], py)

                    if useFlux:
                        # Get the flux re-weighted model and test against
                        # the HeavyFootprint.
                        # The HeavyFootprint needs to be projected onto
                        # the image of the flux-redistributed model,
                        # since the HeavyFootprint may trim rows or columns.
                        parentFootprint = catalog[catalog["id"] == child["parent"]][0].getFootprint()
                        _images = imageForRedistribution[parentFootprint.getBBox()].image.array
                        blend.observation.images = scl.Image(
                            _images[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.observation.weights = scl.Image(
                            parentFootprint.spans.asArray()[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.conserve_flux()
                        model = source.flux_weighted_image.data[0]
                        bbox = scarletBoxToBBox(source.flux_weighted_image.bbox)
                        image = afwImage.ImageF(model, xy0=bbox.getMin())
                        fp.insert(image)
                        np.testing.assert_almost_equal(image.array, model)
                    else:
                        # Get the model for the source and test
                        # against the HeavyFootprint
                        bbox = fp.getBBox()
                        bbox = bboxToScarletBox(bbox)
                        model = blend.observation.convolve(
                            source.get_model().project(bbox=bbox), mode="real"
                        ).data[0]
                        np.testing.assert_almost_equal(img.array, model)

        # Check that all sources have the correct number of peaks
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))

        # Check that only the large footprint was flagged as too big
        largeFootprint = np.zeros(len(catalog), dtype=bool)
        largeFootprint[2] = True
        np.testing.assert_array_equal(largeFootprint, catalog["deblend_parentTooBig"])

        # Check that only the dense footprint was flagged as too dense
        denseFootprint = np.zeros(len(catalog), dtype=bool)
        denseFootprint[3] = True
        np.testing.assert_array_equal(denseFootprint, catalog["deblend_tooManyPeaks"])

        # Check that only the appropriate parents were skipped
        skipped = largeFootprint | denseFootprint
        np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])

    def test_continuity(self):
        """This test ensures that lsst.scarlet.lite gives roughly the same
        result as scarlet.lite

        TODO: This test can be removed once the deprecated scarlet.lite
        module is removed from the science pipelines.
        """
        oldCatalog, oldModelData, oldConfig = self._deblend("old_lite")
        catalog, modelData, config = self._deblend("lite")

        # Ensure that the deblender used different versions
        self.assertEqual(oldConfig.version, "old_lite")
        self.assertEqual(config.version, "lite")

        # Check that the PSF and other properties are the same
        assert_almost_equal(oldModelData.psf, modelData.psf)
        self.assertTupleEqual(tuple(oldModelData.blends.keys()), tuple(modelData.blends.keys()))

        # Make sure that the sources have the same IDs
        for i in range(len(catalog)):
            self.assertEqual(catalog[i]["id"], oldCatalog[i]["id"])

        for blendId in modelData.blends.keys():
            oldBlendData = oldModelData.blends[blendId]
            blendData = modelData.blends[blendId]

            # Check that blend properties are the same
            self.assertTupleEqual(oldBlendData.origin, blendData.origin)
            self.assertTupleEqual(oldBlendData.shape, blendData.shape)
            self.assertTupleEqual(oldBlendData.bands, blendData.bands)
            self.assertTupleEqual(oldBlendData.psf_center, blendData.psf_center)
            self.assertTupleEqual(tuple(oldBlendData.sources.keys()), tuple(blendData.sources.keys()))
            assert_almost_equal(oldBlendData.psf, blendData.psf)

            for sourceId in blendData.sources.keys():
                oldSourceData = oldBlendData.sources[sourceId]
                sourceData = blendData.sources[sourceId]
                # Check that source properties are the same
                self.assertEqual(len(oldSourceData.components), 0)
                self.assertEqual(len(sourceData.components), 0)
                self.assertEqual(
                    len(oldSourceData.factorized_components),
                    len(sourceData.factorized_components)
                )

                for c in range(len(sourceData.factorized_components)):
                    oldComponentData = oldSourceData.factorized_components[c]
                    componentData = sourceData.factorized_components[c]
                    # Check that component properties are the same
                    self.assertTupleEqual(oldComponentData.peak, componentData.peak)
                    # The component origin should be the peak location
                    # minus half the component shape in each dimension.
                    self.assertTupleEqual(
                        tuple(oldComponentData.peak[i]-oldComponentData.shape[i]//2 for i in range(2)),
                        oldComponentData.origin,
                    )
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Standard LSST memory/leak test; behavior comes from the base class."""
    pass
def setup_module(module):
    """Initialize the LSST test machinery when run under pytest."""
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Initialize the LSST test machinery, then defer to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()