Coverage for tests/test_deblend.py: 13%
160 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-19 10:49 +0000
1# This file is part of meas_extensions_scarlet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import unittest
24import numpy as np
25from numpy.testing import assert_almost_equal
27from lsst.geom import Point2I, Point2D
28import lsst.utils.tests
29import lsst.afw.image as afwImage
30from lsst.meas.algorithms import SourceDetectionTask
31from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
32from lsst.meas.extensions.scarlet.utils import bboxToScarletBox, scarletBoxToBBox
33from lsst.meas.extensions.scarlet.io import monochromaticDataToScarlet, updateCatalogFootprints
34import lsst.scarlet.lite as scl
35from lsst.afw.table import SourceCatalog
36from lsst.afw.detection import Footprint
37from lsst.afw.geom import SpanSet, Stencil
39from utils import initData
class TestDeblend(lsst.utils.tests.TestCase):
    """Tests for `ScarletDeblendTask` run on a small synthetic multiband
    coadd containing one blend and one isolated source.
    """

    def setUp(self):
        """Build a 5-band ("grizy") synthetic coadd.

        The image contains a blend of three sources and one isolated source.
        Uniform noise is added to the pixels, the variance plane is set to
        the square of that noise, and each band's PSF comes from ``initData``.
        """
        # Seed the global RNG so that the noise field added below is
        # reproducible between runs.
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
        result = initData(shape, coords, amplitudes)
        targetPsfImage, psfImages, images, channels, seds, morphs, psfs = result

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise

        self.bands = "grizy"
        _images = afwImage.MultibandMaskedImage.fromArrays(
            self.bands,
            images.astype(np.float32),
            None,
            noise**2
        )
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
        self.coadds = afwImage.MultibandExposure.fromExposures(self.bands, coadds)
        for b, coadd in enumerate(self.coadds):
            coadd.setPsf(psfs[b])

    def _deblend(self, version):
        """Detect sources, append two parents that must be skipped, and
        run the deblender.

        Parameters
        ----------
        version : `str`
            Value assigned to the task config's ``version`` field
            (this module uses ``"lite"`` and ``"old_lite"``).

        Returns
        -------
        catalog : `lsst.afw.table.SourceCatalog`
            The catalog returned by ``ScarletDeblendTask.run``.
        modelData
            The scarlet model data returned by ``ScarletDeblendTask.run``.
        config : `ScarletDeblendTask.ConfigClass`
            The configuration used for the run.
        """
        schema = SourceCatalog.Table.makeMinimalSchema()
        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        config.catchFailures = False
        config.version = version

        # Detect sources
        detectionTask = SourceDetectionTask(schema=schema)
        deblendTask = ScarletDeblendTask(schema=schema, config=config)
        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, self.coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large: its box size is derived from
        # sqrt(maxFootprintArea) + 1, so its area exceeds maxFootprintArea.
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks (one more than allowed)
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)

        # Run the deblender
        catalog, modelData = deblendTask.run(self.coadds, catalog)
        return catalog, modelData, config

    def test_deblend_task(self):
        """Check deblended catalog bookkeeping and compare every child's
        HeavyFootprint with the stored scarlet model.

        The comparison is done in every band, both with and without flux
        re-distribution.
        """
        catalog, modelData, config = self._deblend("lite")

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution to test all of the different possible
        # options of loading catalog footprints.
        for useFlux in [False, True]:
            for bandIndex, band in enumerate(self.bands):
                coadd = self.coadds[band]
                imageForRedistribution = coadd if useFlux else None

                updateCatalogFootprints(
                    modelData,
                    catalog,
                    band=band,
                    imageForRedistribution=imageForRedistribution,
                    removeScarletData=False,
                )

                # Check that the number of deblended children is consistent
                parents = catalog[catalog["parent"] == 0]
                self.assertEqual(np.sum(catalog["deblend_nChild"]), len(catalog)-len(parents))

                # Check that the models have not been cleared
                # from the modelData
                self.assertEqual(len(modelData.blends), np.sum(~parents["deblend_skipped"]))

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]
                    # Check that nChild is set correctly
                    self.assertEqual(len(children), parent.get("deblend_nChild"))
                    # Check that parent columns are propagated
                    # to their children
                    for parentCol, childCol in config.columnInheritance.items():
                        np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

                children = catalog[catalog["parent"] != 0]
                for child in children:
                    fp = child.getFootprint()
                    img = fp.extractImage(fill=0.0)
                    # Check that the flux at the center is correct.
                    # Note: this only works in this test image because the
                    # detected peak is in the same location as the
                    # scarlet peak.
                    # If the peak is shifted, the flux value will be correct
                    # but deblend_peak_center is not the correct location.
                    px = child.get("deblend_peak_center_x")
                    py = child.get("deblend_peak_center_y")
                    flux = img[Point2I(px, py)]
                    self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                    # Check that the peak positions match the catalog entry
                    peaks = fp.getPeaks()
                    self.assertEqual(px, peaks[0].getIx())
                    self.assertEqual(py, peaks[0].getIy())

                    # Load the data to check against the HeavyFootprint
                    blendData = modelData.blends[child["parent"]]
                    # We need to set an observation in order to convolve
                    # the model.
                    position = Point2D(*blendData.psf_center[::-1])
                    _psfs = self.coadds[band].getPsf().computeKernelImage(position).array[None, :, :]
                    modelBox = scl.Box(blendData.shape, origin=blendData.origin)
                    observation = scl.Observation.empty(
                        bands=("dummy", ),
                        psfs=_psfs,
                        model_psf=modelData.psf[None, :, :],
                        bbox=modelBox,
                        dtype=np.float32,
                    )
                    blend = monochromaticDataToScarlet(
                        blendData=blendData,
                        bandIndex=bandIndex,
                        observation=observation,
                    )
                    # The stored PSF should be the same as the calculated one
                    assert_almost_equal(blendData.psf[bandIndex:bandIndex+1], _psfs)

                    # Get the scarlet model for the source; exactly one
                    # scarlet source should carry this child's record ID.
                    matches = [src for src in blend.sources if src.record_id == child.getId()]
                    self.assertEqual(len(matches), 1)
                    source = matches[0]
                    self.assertEqual(source.center[1], px)
                    self.assertEqual(source.center[0], py)

                    if useFlux:
                        # Get the flux re-weighted model and test against
                        # the HeavyFootprint.
                        # The HeavyFootprint needs to be projected onto
                        # the image of the flux-redistributed model,
                        # since the HeavyFootprint may trim rows or columns.
                        parentFootprint = catalog[catalog["id"] == child["parent"]][0].getFootprint()
                        _images = imageForRedistribution[parentFootprint.getBBox()].image.array
                        blend.observation.images = scl.Image(
                            _images[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.observation.weights = scl.Image(
                            parentFootprint.spans.asArray()[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.conserve_flux()
                        model = source.flux_weighted_image.data[0]
                        bbox = scarletBoxToBBox(source.flux_weighted_image.bbox)
                        # Build the image from a copy of the model. An afw image
                        # constructed from a numpy array shares its memory by
                        # default (deep=False), so without the copy
                        # ``fp.insert`` would also overwrite ``model`` and the
                        # comparison below would be vacuous.
                        image = afwImage.ImageF(model.copy(), xy0=bbox.getMin())
                        fp.insert(image)
                        np.testing.assert_almost_equal(image.array, model)
                    else:
                        # Get the model for the source and test
                        # against the HeavyFootprint
                        bbox = fp.getBBox()
                        bbox = bboxToScarletBox(bbox)
                        model = blend.observation.convolve(source.get_model().project(bbox=bbox)).data[0]
                        np.testing.assert_almost_equal(img.array, model)

        # Check that all sources have the correct number of peaks
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))

        # NOTE: the hard-coded indices below assume detection finds exactly
        # two parents (the blend and the isolated source), so the parents
        # appended in ``_deblend`` sit at rows 2 (too big) and 3 (too dense).
        # Check that only the large footprint was flagged as too big
        largeFootprint = np.zeros(len(catalog), dtype=bool)
        largeFootprint[2] = True
        np.testing.assert_array_equal(largeFootprint, catalog["deblend_parentTooBig"])

        # Check that only the dense footprint was flagged as too dense
        denseFootprint = np.zeros(len(catalog), dtype=bool)
        denseFootprint[3] = True
        np.testing.assert_array_equal(denseFootprint, catalog["deblend_tooManyPeaks"])

        # Check that only the appropriate parents were skipped
        skipped = largeFootprint | denseFootprint
        np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])

    def test_continuity(self):
        """This test ensures that lsst.scarlet.lite gives roughly the same
        result as scarlet.lite

        TODO: This test can be removed once the deprecated scarlet.lite
        module is removed from the science pipelines.
        """
        oldCatalog, oldModelData, oldConfig = self._deblend("old_lite")
        catalog, modelData, config = self._deblend("lite")

        # Ensure that the deblender used different versions
        self.assertEqual(oldConfig.version, "old_lite")
        self.assertEqual(config.version, "lite")

        # Check that the PSF and other properties are the same
        assert_almost_equal(oldModelData.psf, modelData.psf)
        self.assertTupleEqual(tuple(oldModelData.blends.keys()), tuple(modelData.blends.keys()))

        # Make sure that the sources have the same IDs. Compare lengths
        # first so that extra or missing rows in either catalog fail
        # the test instead of being silently skipped.
        self.assertEqual(len(catalog), len(oldCatalog))
        for record, oldRecord in zip(catalog, oldCatalog):
            self.assertEqual(record["id"], oldRecord["id"])

        for blendId, blendData in modelData.blends.items():
            oldBlendData = oldModelData.blends[blendId]

            # Check that blend properties are the same
            self.assertTupleEqual(oldBlendData.origin, blendData.origin)
            self.assertTupleEqual(oldBlendData.shape, blendData.shape)
            self.assertTupleEqual(oldBlendData.bands, blendData.bands)
            self.assertTupleEqual(oldBlendData.psf_center, blendData.psf_center)
            self.assertTupleEqual(tuple(oldBlendData.sources.keys()), tuple(blendData.sources.keys()))
            assert_almost_equal(oldBlendData.psf, blendData.psf)

            for sourceId, sourceData in blendData.sources.items():
                oldSourceData = oldBlendData.sources[sourceId]
                # Check that source properties are the same
                self.assertEqual(len(oldSourceData.components), 0)
                self.assertEqual(len(sourceData.components), 0)
                self.assertEqual(
                    len(oldSourceData.factorized_components),
                    len(sourceData.factorized_components)
                )

                # Lengths were asserted equal above, so zip pairs every
                # component.
                for oldComponentData, componentData in zip(
                    oldSourceData.factorized_components,
                    sourceData.factorized_components,
                ):
                    # Check that component properties are the same
                    self.assertTupleEqual(oldComponentData.peak, componentData.peak)
                    # The old component's origin should equal its peak minus
                    # half its (integer) shape in each of the two dimensions.
                    self.assertTupleEqual(
                        tuple(oldComponentData.peak[i]-oldComponentData.shape[i]//2 for i in range(2)),
                        oldComponentData.origin,
                    )
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Inherit the standard checks from ``lsst.utils.tests.MemoryTestCase``.

    No additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook: initialize ``lsst.utils.tests``.

    The ``module`` argument is supplied by the test runner and is unused.
    """
    lsst.utils.tests.init()
# Allow running this file directly as a script (outside pytest).
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()