Coverage for tests/test_deblend.py: 11% of 188 statements (coverage.py v7.5.1, created 2024-05-11 11:47 +0000)
# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import unittest

import numpy as np
from numpy.testing import assert_almost_equal

from lsst.geom import Point2I, Point2D
import lsst.utils.tests
import lsst.afw.image as afwImage
from lsst.meas.algorithms import SourceDetectionTask
from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
from lsst.meas.extensions.scarlet.utils import bboxToScarletBox, scarletBoxToBBox
from lsst.meas.extensions.scarlet.io import monochromaticDataToScarlet, updateCatalogFootprints
import lsst.scarlet.lite as scl
from lsst.afw.table import SourceCatalog
from lsst.afw.detection import Footprint
from lsst.afw.geom import SpanSet, Stencil

from utils import initData


class TestDeblend(lsst.utils.tests.TestCase):
    def setUp(self):
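        """Build a simulated multiband coadd containing a three-source
        blend and an isolated source.
        """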
        # Set the random seed so that the noise field is reproducible
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
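        # initData builds the simulated scene: a target PSF image, per-band
        # PSF images, noiseless images, and the channels, SEDs, morphologies,
        # and PSFs used to generate them.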
        result = initData(shape, coords, amplitudes)
        targetPsfImage, psfImages, images, channels, seds, morphs, psfs = result
        B, Ny, Nx = shape

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise

        self.bands = "grizy"
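        # Build the multiband image with no mask plane; since the noise
        # realization is known exactly, its square serves as the variance plane.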
        _images = afwImage.MultibandMaskedImage.fromArrays(
            self.bands,
            images.astype(np.float32),
            None,
            noise**2
        )
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
        self.coadds = afwImage.MultibandExposure.fromExposures(self.bands, coadds)
        for b, coadd in enumerate(self.coadds):
            coadd.setPsf(psfs[b])

    def _insert_blank_source(self, modelData, catalog):
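        """Insert a parent record and a child whose serialized scarlet model
        has identically zero flux, to exercise the deblend_zeroFlux flag.
        """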
        # Add parent
        parent = catalog.addNew()
        parent.setParent(0)
        parent["deblend_nChild"] = 1
        parent["deblend_nPeaks"] = 1
        ss = SpanSet.fromShape(5, Stencil.CIRCLE, offset=(30, 70))
        footprint = Footprint(ss)
        peak = footprint.addPeak(30, 70, 0)
        parent.setFootprint(footprint)

        # Add the zero flux source
        dtype = np.float32
        center = (70, 30)
        origin = (center[0]-5, center[1]-5)
        psf = list(modelData.blends.values())[0].psf
        src = catalog.addNew()
        src.setParent(parent.getId())
        src["deblend_peak_center_x"] = center[1]
        src["deblend_peak_center_y"] = center[0]
        src["deblend_nPeaks"] = 1
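
        # Serialize the source with no components and a single factorized
        # component whose spectrum and morphology are all zeros, so the
        # reconstructed model contains no flux.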
        sources = {
            src.getId(): {
                "components": [],
                "factorized": [{
                    "origin": origin,
                    "peak": center,
                    "spectrum": np.zeros((len(self.bands),), dtype=dtype),
                    "morph": np.zeros((11, 11), dtype=dtype),
                    "shape": (11, 11),
                }],
                "peak_id": peak.getId(),
            }
        }
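
        # Package the blend metadata and the zero-flux source into the
        # persisted scarlet data model, registered under the parent id.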
        blendData = scl.io.ScarletBlendData.from_dict({
            "origin": origin,
            "shape": (11, 11),
            "psf_center": center,
            "psf_shape": psf.shape,
            "psf": psf.flatten(),
            "sources": sources,
            "bands": self.bands,
        })
        pid = parent.getId()
        modelData.blends[pid] = blendData
        return pid, src.getId()

    def _deblend(self, version):
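        """Run detection and ScarletDeblendTask on the simulated coadds,
        first adding one oversized and one over-dense footprint to test
        that such parents are skipped.
        """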
        schema = SourceCatalog.Table.makeMinimalSchema()
        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        config.catchFailures = False
        config.version = version

        # Detect sources
        detectionTask = SourceDetectionTask(schema=schema)
        deblendTask = ScarletDeblendTask(schema=schema, config=config)
        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, self.coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)

        # Run the deblender
        catalog, modelData = deblendTask.run(self.coadds, catalog)
        return catalog, modelData, config

    def test_deblend_task(self):
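        """Check the flags, footprints, and models produced by
        ScarletDeblendTask, both with and without flux re-distribution.
        """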
        catalog, modelData, config = self._deblend("lite")

        bad_blend_id, bad_src_id = self._insert_blank_source(modelData, catalog)

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution, to test all of the supported ways of
        # loading catalog footprints.
        for useFlux in [False, True]:
            for band in self.bands:
                bandIndex = self.bands.index(band)
                coadd = self.coadds[band]

                if useFlux:
                    imageForRedistribution = coadd
                else:
                    imageForRedistribution = None

                updateCatalogFootprints(
                    modelData,
                    catalog,
                    band=band,
                    imageForRedistribution=imageForRedistribution,
                    removeScarletData=False,
                )

                # Check that the number of deblended children is consistent
                parents = catalog[catalog["parent"] == 0]
                self.assertEqual(np.sum(catalog["deblend_nChild"]), len(catalog)-len(parents))

                # Check that the models have not been cleared
                # from the modelData
                self.assertEqual(len(modelData.blends), np.sum(~parents["deblend_skipped"]))

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]
                    # Check that nChild is set correctly
                    self.assertEqual(len(children), parent.get("deblend_nChild"))
                    # Check that parent columns are propagated
                    # to their children
                    if parent.getId() == bad_blend_id:
                        continue
                    for parentCol, childCol in config.columnInheritance.items():
                        np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

                children = catalog[catalog["parent"] != 0]
                for child in children:
                    fp = child.getFootprint()
                    img = fp.extractImage(fill=0.0)
                    # Check that the flux at the center is correct.
                    # Note: this only works in this test image because the
                    # detected peak is in the same location as the
                    # scarlet peak.
                    # If the peak is shifted, the flux value will be correct
                    # but deblend_peak_center will not be the correct location.
                    px = child.get("deblend_peak_center_x")
                    py = child.get("deblend_peak_center_y")
                    flux = img[Point2I(px, py)]
                    self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                    # Check that the peak positions match the catalog entry
                    peaks = fp.getPeaks()
                    self.assertEqual(px, peaks[0].getIx())
                    self.assertEqual(py, peaks[0].getIy())

                    # Load the data to check against the HeavyFootprint
                    blendData = modelData.blends[child["parent"]]
                    # We need to set an observation in order to convolve
                    # the model.
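                    # psf_center is stored as (y, x), so reverse it for the
                    # (x, y) Point2D constructor.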
                    position = Point2D(*blendData.psf_center[::-1])
                    _psfs = self.coadds[band].getPsf().computeKernelImage(position).array[None, :, :]
                    modelBox = scl.Box(blendData.shape, origin=blendData.origin)
                    observation = scl.Observation.empty(
                        bands=("dummy", ),
                        psfs=_psfs,
                        model_psf=modelData.psf[None, :, :],
                        bbox=modelBox,
                        dtype=np.float32,
                    )
                    blend = monochromaticDataToScarlet(
                        blendData=blendData,
                        bandIndex=bandIndex,
                        observation=observation,
                    )
                    # The stored PSF should be the same as the calculated one
                    assert_almost_equal(blendData.psf[bandIndex:bandIndex+1], _psfs)

                    # Get the scarlet model for the source
                    source = [src for src in blend.sources if src.record_id == child.getId()][0]
                    self.assertEqual(source.center[1], px)
                    self.assertEqual(source.center[0], py)

                    if useFlux:
                        # Get the flux re-weighted model and test against
                        # the HeavyFootprint.
                        # The HeavyFootprint needs to be projected onto
                        # the image of the flux-redistributed model,
                        # since the HeavyFootprint may trim rows or columns.
                        parentFootprint = catalog[catalog["id"] == child["parent"]][0].getFootprint()
                        _images = imageForRedistribution[parentFootprint.getBBox()].image.array
                        blend.observation.images = scl.Image(
                            _images[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
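                        # Weight by the parent footprint's span mask so that
                        # flux is only redistributed within the footprint.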
                        blend.observation.weights = scl.Image(
                            parentFootprint.spans.asArray()[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.conserve_flux()
                        model = source.flux_weighted_image.data[0]
                        bbox = scarletBoxToBBox(source.flux_weighted_image.bbox)
                        image = afwImage.ImageF(model, xy0=bbox.getMin())
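                        # Overwrite the model image with the HeavyFootprint
                        # pixels; the two must agree wherever the footprint
                        # has coverage for the assertion below to pass.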
                        fp.insert(image)
                        np.testing.assert_almost_equal(image.array, model)
                    else:
                        # Get the model for the source and test
                        # against the HeavyFootprint
                        bbox = fp.getBBox()
                        bbox = bboxToScarletBox(bbox)
                        model = blend.observation.convolve(
                            source.get_model().project(bbox=bbox), mode="real"
                        ).data[0]
                        np.testing.assert_almost_equal(img.array, model)

        # Check that all sources have the correct number of peaks
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))

        # Check that only the large footprint was flagged as too big
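        # The two detected parents occupy indices 0 and 1, so the oversized
        # footprint added in _deblend is at index 2 and the dense footprint
        # at index 3.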
        largeFootprint = np.zeros(len(catalog), dtype=bool)
        largeFootprint[2] = True
        np.testing.assert_array_equal(largeFootprint, catalog["deblend_parentTooBig"])

        # Check that only the dense footprint was flagged as too dense
        denseFootprint = np.zeros(len(catalog), dtype=bool)
        denseFootprint[3] = True
        np.testing.assert_array_equal(denseFootprint, catalog["deblend_tooManyPeaks"])

        # Check that only the appropriate parents were skipped
        skipped = largeFootprint | denseFootprint
        np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])

        # Check that the zero flux source was flagged
        for src in catalog:
            np.testing.assert_equal(src["deblend_zeroFlux"], src.getId() == bad_src_id)

    def test_continuity(self):
        """This test ensures that lsst.scarlet.lite gives roughly the same
        result as scarlet.lite.

        TODO: This test can be removed once the deprecated scarlet.lite
        module is removed from the science pipelines.
        """
        oldCatalog, oldModelData, oldConfig = self._deblend("old_lite")
        catalog, modelData, config = self._deblend("lite")

        # Ensure that the deblender used different versions
        self.assertEqual(oldConfig.version, "old_lite")
        self.assertEqual(config.version, "lite")

        # Check that the PSF and other properties are the same
        assert_almost_equal(oldModelData.psf, modelData.psf)
        self.assertTupleEqual(tuple(oldModelData.blends.keys()), tuple(modelData.blends.keys()))

        # Make sure that the sources have the same IDs
        for i in range(len(catalog)):
            self.assertEqual(catalog[i]["id"], oldCatalog[i]["id"])

        for blendId in modelData.blends.keys():
            oldBlendData = oldModelData.blends[blendId]
            blendData = modelData.blends[blendId]

            # Check that blend properties are the same
            self.assertTupleEqual(oldBlendData.origin, blendData.origin)
            self.assertTupleEqual(oldBlendData.shape, blendData.shape)
            self.assertTupleEqual(oldBlendData.bands, blendData.bands)
            self.assertTupleEqual(oldBlendData.psf_center, blendData.psf_center)
            self.assertTupleEqual(tuple(oldBlendData.sources.keys()), tuple(blendData.sources.keys()))
            assert_almost_equal(oldBlendData.psf, blendData.psf)

            for sourceId in blendData.sources.keys():
                oldSourceData = oldBlendData.sources[sourceId]
                sourceData = blendData.sources[sourceId]
                # Check that source properties are the same
                self.assertEqual(len(oldSourceData.components), 0)
                self.assertEqual(len(sourceData.components), 0)
                self.assertEqual(
                    len(oldSourceData.factorized_components),
                    len(sourceData.factorized_components)
                )

                for c in range(len(sourceData.factorized_components)):
                    oldComponentData = oldSourceData.factorized_components[c]
                    componentData = sourceData.factorized_components[c]
                    # Check that component properties are the same
                    self.assertTupleEqual(oldComponentData.peak, componentData.peak)
                    self.assertTupleEqual(
                        tuple(oldComponentData.peak[i]-oldComponentData.shape[i]//2 for i in range(2)),
                        oldComponentData.origin,
                    )


class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()