Coverage for tests / test_deblend.py: 12%
423 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:00 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:00 +0000
# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24import tempfile
26import lsst.afw.image as afwImage
27import lsst.meas.extensions.scarlet as mes
28import lsst.scarlet.lite as scl
29import lsst.utils.tests
30import numpy as np
31from lsst.afw.detection import Footprint, GaussianPsf, InvalidPsfError, PeakTable, Psf
32from lsst.afw.geom import SpanSet
33from lsst.afw.table import SourceCatalog, SourceTable, SchemaMapper
34import lsst.daf.butler
35from lsst.daf.butler import Butler, Config, DatasetType, StorageClass, FileDataset, DatasetRef
36from lsst.daf.butler.tests import makeTestRepo, makeTestCollection
37from lsst.geom import Extent2I, Point2D, Point2I
38from lsst.meas.algorithms import SourceDetectionTask
39from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
40from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask
41from lsst.pipe.base import NoWorkFound, Struct
42from utils import initData, SersicModel, PsfModel
# Absolute path of the directory holding this test module; test data files
# (e.g. ``data/v29_models.json``) are resolved relative to it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
class TestDeblend(lsst.utils.tests.TestCase):
    """End-to-end tests of deconvolution, scarlet deblending and persistence.

    Each test builds a synthetic three-band (g, r, i) scene from
    ``self.models`` with ``initData``, adds a fixed-seed noise field, and
    runs `DeconvolveExposureTask` and/or `ScarletDeblendTask` on it.
    """

    def setUp(self):
        # PSF of the deconvolved (model) frame used by scarlet.
        self.modelPsf = scl.utils.integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        psfRadius = 20
        psfShape = (2 * psfRadius + 1, 2 * psfRadius + 1)
        # One observed PSF per band, increasing in width from g to i.
        self.psfs = [
            GaussianPsf(psfShape[1], psfShape[0], 1.0),
            GaussianPsf(psfShape[1], psfShape[0], 1.2),
            GaussianPsf(psfShape[1], psfShape[0], 1.4),
        ]
        self.imagePsf = np.asarray(
            [psf.computeImage(psf.getAveragePosition()).array for psf in self.psfs]
        ).astype(np.float32)
        # Normalize each band's PSF image to unit sum.
        self.imagePsf /= self.imagePsf.sum(axis=(1, 2))[:, None, None]
        self.bands = tuple("gri")

        self.models = [
            # Isolated source
            PsfModel(
                center=(30, 15),
                spectrum=np.array([8, 2, 1]),
                bands=self.bands,
            ),
            # Two source blend
            SersicModel(
                center=(40, 20),
                major=5,
                minor=2,
                radius=15,
                theta=-np.pi/4,
                n=1,
                spectrum=np.array([2, 4, 8]),
                bands=self.bands,
            ),
            PsfModel(
                center=(12, 20),
                spectrum=np.array([1, 2, 8]),
                bands=self.bands,
            ),
            # 3 source blend
            SersicModel(
                center=(25, 70),
                major=5,
                minor=2,
                radius=20,
                theta=np.pi/48,
                n=1,
                spectrum=np.array([2, 5, 8]),
                bands=self.bands,
            ),
            PsfModel(
                center=(32, 60),
                spectrum=np.array([1, 2, 8]),
                bands=self.bands,
            ),
            PsfModel(
                center=(16, 80),
                spectrum=np.array([8, 2, 1]),
                bands=self.bands,
            ),
            # Large blend
            SersicModel(
                center=(70, 70),
                major=5,
                minor=2,
                radius=25,
                theta=0,
                n=1,
                spectrum=np.array([2, 10, 18]),
                bands=self.bands,
            ),
            SersicModel(
                center=(85, 85),
                major=5,
                minor=2,
                radius=25,
                theta=np.pi/2,
                n=1,
                spectrum=np.array([5, 10, 20]),
                bands=self.bands,
            ),
        ]

    def scarlet_image_to_exposure(
        self,
        image: scl.Image,
        noise: np.ndarray,
    ) -> afwImage.MultibandExposure:
        """Convert a scarlet ``Image`` into an afw `MultibandExposure`.

        Parameters
        ----------
        image :
            Multiband scarlet image whose ``data`` becomes the image plane.
        noise :
            Noise realization; ``noise**2`` is used as the variance plane.

        Returns
        -------
        mCoadd :
            Multiband exposure with ``self.psfs[b]`` attached to band ``b``.
        """
        masked_image = afwImage.MultibandMaskedImage.fromArrays(
            image.bands, image.data, None, noise**2
        )
        coadds = [
            afwImage.Exposure(img, dtype=img.image.array.dtype) for img in masked_image
        ]
        mCoadd = afwImage.MultibandExposure.fromExposures(image.bands, coadds)
        for b, coadd in enumerate(mCoadd):
            coadd.setPsf(self.psfs[b])
        return mCoadd

    def initialize_data(
        self,
        models,
        deconvolveConfig=None,
        deblendConfig=None,
        doDetect: bool = True,
    ):
        """Build the synthetic scene and the tasks used by each test.

        Parameters
        ----------
        models :
            Source models passed to ``initData``.
        deconvolveConfig :
            Optional `DeconvolveExposureTask` config (default config if None).
        deblendConfig :
            Optional `ScarletDeblendTask` config (default config if None).
        doDetect :
            If True, run detection on the r-band coadd and attach the
            resulting ``catalog`` to the returned struct.

        Returns
        -------
        result : `Struct`
            ``deconvolved``, ``convolved``, ``noise``, ``noisyImage``,
            ``mCoadd``, the three tasks, and ``catalog`` when ``doDetect``.
        """
        if deconvolveConfig is None:
            deconvolveConfig = DeconvolveExposureTask.ConfigClass()
        if deblendConfig is None:
            deblendConfig = ScarletDeblendTask.ConfigClass()
        # Generate the data for the test
        deconvolved, convolved = initData(models, self.modelPsf, self.imagePsf)
        # Set the random seed so that the noise field is unaffected
        # and add noise to the image
        np.random.seed(0)
        noise = 0.05 * (np.random.rand(*convolved.shape).astype(np.float32) - 0.5)
        noisyImage = convolved.copy()
        noisyImage._data += noise
        # Create the multiband coadd
        mCoadd = self.scarlet_image_to_exposure(noisyImage, noise)
        # Initialize tasks
        inputSchema = SourceTable.makeMinimalSchema()
        table = SourceTable.make(inputSchema)
        detectionTask = SourceDetectionTask(schema=inputSchema)
        schemaMapper = SchemaMapper(inputSchema)
        schemaMapper.addMinimalSchema(inputSchema)
        schema = schemaMapper.getOutputSchema()
        deconvolveTask = DeconvolveExposureTask(config=deconvolveConfig)
        deblendTask = ScarletDeblendTask(schema=schema, config=deblendConfig)

        result = Struct(
            deconvolved=deconvolved,
            convolved=convolved,
            noise=noise,
            noisyImage=noisyImage,
            mCoadd=mCoadd,
            detectionTask=detectionTask,
            deconvolveTask=deconvolveTask,
            deblendTask=deblendTask,
        )

        if doDetect:
            # Generate a detection catalog
            detectionResult = detectionTask.run(table, mCoadd["r"])
            table = SourceCatalog.Table.make(schema)
            catalog = SourceCatalog(table)
            catalog.extend(detectionResult.sources, schemaMapper)
            result.catalog = catalog

        return result

    def deconvolve(self, data: Struct):
        """Deconvolve every band of ``data.mCoadd`` with ``data.deconvolveTask``.

        The detection catalog is passed to the task only when its
        ``useFootprints`` config option is set.
        """
        deconvolvedCoadds = []
        deconvolveTask = data.deconvolveTask
        if deconvolveTask.config.useFootprints:
            catalog = data.catalog
        else:
            catalog = None
        for coadd in data.mCoadd:
            deconvolvedCoadd = deconvolveTask.run(coadd, catalog).deconvolved
            deconvolvedCoadds.append(deconvolvedCoadd)
        mDeconvolved = afwImage.MultibandExposure.fromExposures(self.bands, deconvolvedCoadds)
        return mDeconvolved

    def _assertDeconvolutionResiduals(self, data: Struct, deconvolved: afwImage.MultibandExposure):
        """Compare a deconvolved exposure against the noiseless truth.

        The g and r bands must agree to within 10 noise sigma and the
        i band to within 20 noise sigma.
        """
        diff = data.deconvolved.data - deconvolved.image.array
        # Due to peakiness of Sersic models the center has a sharp peak,
        # so we ignore a 3x3 region around each source center (clipped at
        # the image edge so that negative indices cannot wrap around).
        for model in self.models:
            yc, xc = model.center
            diff[:, max(yc - 1, 0):yc + 2, max(xc - 1, 0):xc + 2] = 0
        noiseStd = np.std(data.noise)
        # NOTE(review): these bounds are one-sided (only positive residuals
        # are limited); confirm whether np.abs(diff) was intended.
        self.assertLess(np.max(diff[:2]), 10*noiseStd)
        self.assertLess(np.max(diff[2]), 20*noiseStd)

    def test_default_deconvolve(self):
        data = self.initialize_data(self.models)
        deconvolved = self.deconvolve(data)
        self._assertDeconvolutionResiduals(data, deconvolved)

        context = mes.scarletDeblendTask.ScarletDeblendContext.build(
            data.mCoadd,
            deconvolved,
            data.catalog,
            data.deblendTask.ConfigClass()
        )

        self.assertEqual(len(context.footprints), 4)

    def test_catalog_free_deconvolve(self):
        config = DeconvolveExposureTask.ConfigClass()
        config.useFootprints = False
        data = self.initialize_data(self.models, deconvolveConfig=config)
        deconvolved = self.deconvolve(data)
        self._assertDeconvolutionResiduals(data, deconvolved)

    def test_footprints(self):
        data = self.initialize_data(self.models)
        mDeconvolved = self.deconvolve(data)
        result = data.deblendTask.run(data.mCoadd, mDeconvolved, data.catalog)

        catalog = result.deblendedCatalog
        objectParents = result.objectParents
        modelData = result.scarletModelData
        observedPsf = modelData.metadata["psf"]
        modelPsf = modelData.metadata["model_psf"]

        # Check that isolated sources are handled correctly
        isolated = catalog[(catalog["parent"] == 0)]
        self.assertEqual(len(isolated), len(modelData.isolated))
        for sid, source in modelData.isolated.items():
            catalog_footprint = catalog.find(sid).getFootprint()
            isolated_array = catalog_footprint.spans.asArray()
            np.testing.assert_array_equal(source.span_array, isolated_array)

            # Check that the origin is correct
            self.assertTupleEqual(source.origin[::-1], tuple(catalog_footprint.getBBox().getMin()))

        # Verify that the isolated parent flag is being set
        isolatedParents = objectParents[
            (objectParents["parent"] == 0)
            & (objectParents["deblend_nPeaks"] == 1)
        ]
        self.assertEqual(np.sum(objectParents["deblend_skipped_isolatedParent"]), len(isolatedParents))

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution to test all of the different possible
        # options of loading catalog footprints.
        for useFlux in [False, True]:
            for band in self.bands:
                bandIndex = self.bands.index(band)
                coadd = data.mCoadd[band]

                if useFlux:
                    imageForRedistribution = coadd
                else:
                    imageForRedistribution = None

                mes.io.updateCatalogFootprints(
                    modelData,
                    catalog,
                    band=band,
                    imageForRedistribution=imageForRedistribution,
                    removeScarletData=False,
                    updateFluxColumns=True,
                )

                # Check that the number of deblended children is consistent
                parents = objectParents[
                    (objectParents["parent"] == 0) & (objectParents["deblend_nPeaks"] > 1)]
                self.assertEqual(
                    np.sum(parents["deblend_nChild"]), len(catalog) - len(isolated)
                )

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]

                    # Extract the parent blend data
                    parentBlendData = modelData.blends[parent.getId()]
                    parentFootprint = parent.getFootprint()
                    x0, y0 = parentFootprint.getBBox().getMin()
                    width, height = parentFootprint.getBBox().getDimensions()
                    yx0 = (y0, x0)

                    for child in children:
                        fp = child.getFootprint()
                        img = fp.extractImage(fill=0.0)
                        # Check that the flux at the center is correct.
                        # Note: this only works in this test image because the
                        # detected peak is in the same location as the
                        # scarlet peak.
                        # If the peak is shifted,
                        # the flux value will be correct but
                        # deblend_peak_center is not the correct location.
                        px = child.get("deblend_peak_center_x")
                        py = child.get("deblend_peak_center_y")
                        flux = img[Point2I(px, py)]
                        self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                        # Check that the peak positions match the catalog entry
                        peaks = fp.getPeaks()
                        self.assertEqual(px, peaks[0].getIx())
                        self.assertEqual(py, peaks[0].getIy())

                        # Load the data to check against the HeavyFootprint
                        blendData = parentBlendData.children[child["deblend_blendId"]]
                        # We need to set an observation in order to convolve
                        # the model.
                        modelBox = scl.Box((height, width), origin=(y0, x0))
                        observation = scl.Observation.empty(
                            bands=("dummy",),
                            psfs=observedPsf[bandIndex][None, :, :],
                            model_psf=modelPsf[None, :, :],
                            bbox=modelBox,
                            dtype=np.float32,
                        )
                        blend = mes.io.monochromaticDataToScarlet(
                            blendData=blendData,
                            bandIndex=bandIndex,
                            observation=observation,
                        )

                        # Get the scarlet model for the source
                        source = next(
                            src for src in blend.sources if src.metadata["id"] == child.getId()
                        )
                        self.assertEqual(source.center[1], px)
                        self.assertEqual(source.center[0], py)

                        if useFlux:
                            assert imageForRedistribution is not None
                            # Get the flux re-weighted model and test against
                            # the HeavyFootprint.
                            # The HeavyFootprint needs to be projected onto
                            # the image of the flux-redistributed model,
                            # since the HeavyFootprint
                            # may trim rows or columns.
                            _images = imageForRedistribution[
                                parentFootprint.getBBox()
                            ].image.array
                            blend.observation.images = scl.Image(
                                _images[None, :, :],
                                yx0=yx0,
                                bands=("dummy",),
                            )
                            blend.observation.weights = scl.Image(
                                parentFootprint.spans.asArray()[None, :, :],
                                yx0=yx0,
                                bands=("dummy",),
                            )
                            blend.conserve_flux()
                            model = source.flux_weighted_image.data[0]
                            my0, mx0 = source.flux_weighted_image.yx0
                            # Build the afw image from a *copy* of the model:
                            # an ImageF constructed from a numpy array shares
                            # its memory by default, so inserting the
                            # footprint would also overwrite ``model`` and
                            # make the comparison below vacuous.
                            image = afwImage.ImageF(model.copy(), xy0=Point2I(mx0, my0))
                            fp.insert(image)
                            np.testing.assert_almost_equal(image.array, model)
                        else:
                            # Get the model for the source and test
                            # against the HeavyFootprint
                            bbox = fp.getBBox()
                            bbox = mes.utils.bboxToScarletBox(bbox)
                            model = blend.observation.convolve(
                                source.get_model().project(bbox=bbox), mode="real"
                            ).data[0]
                            np.testing.assert_almost_equal(img.array, model)

        # Check that all sources have the correct number of peaks
        maxId = np.max(objectParents["id"])
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), 1)
            if src["parent"] > 0:
                # Check that source IDs are greater than the max parent ID
                self.assertGreater(src["id"], maxId)

        # Ensure that sources are sorted by parent ID
        np.testing.assert_array_equal(sorted(catalog["parent"]), catalog["parent"])

        # Check that the catalog matches the expected results
        self.assertEqual(len(catalog), len(self.models))

    def test_skipped(self):
        # Use tight configs to force skipping a 3 source footprint
        # and "large" footprint
        config = ScarletDeblendTask.ConfigClass()
        config.maxFootprintArea = 2000
        config.maxNumberOfPeaks = 2
        config.catchFailures = False

        data = self.initialize_data(self.models, deblendConfig=config)
        mDeconvolved = self.deconvolve(data)
        result = data.deblendTask.run(data.mCoadd, mDeconvolved, data.catalog)

        catalog = result.objectParents
        parents = catalog[catalog["parent"] == 0]
        self.assertEqual(np.sum(parents["deblend_skipped"]), 2)
        self.assertEqual(np.sum(parents["deblend_skipped_parentTooBig"]), 1)
        self.assertEqual(np.sum(parents["deblend_skipped_tooManyPeaks"]), 1)

    def test_persistence(self):
        # Test that the model data is persisted correctly
        data = self.initialize_data(self.models)
        repo = self._setup_butler()
        mDeconvolved = self.deconvolve(data)
        result = data.deblendTask.run(data.mCoadd, mDeconvolved, data.catalog)
        modelData = result.scarletModelData
        bands = modelData.metadata["bands"]
        butler = makeTestCollection(repo, uniqueId="test_run1")
        butler.put(modelData, "scarlet_model_data", dataId={})
        modelData2 = butler.get("scarlet_model_data", dataId={})
        model_psf = modelData.metadata["model_psf"][None, :, :]
        model_psf2 = modelData2.metadata["model_psf"][None, :, :]
        np.testing.assert_almost_equal(model_psf2, model_psf)
        psf = modelData.metadata["psf"]
        psf2 = modelData2.metadata["psf"]
        np.testing.assert_almost_equal(psf2, psf)
        self.assertEqual(len(modelData2.blends), len(modelData.blends))
        self.assertEqual(len(modelData2.isolated), len(modelData.isolated))

        blendIds = list(modelData.blends)
        for parentId in blendIds:
            nChildren = len(modelData.blends[parentId].children)
            self.assertEqual(nChildren, len(modelData2.blends[parentId].children))
            for blendId in modelData.blends[parentId].children:
                blendData1 = modelData.blends[parentId].children[blendId]
                blendData2 = modelData2.blends[parentId].children[blendId]
                self._test_blend(blendData1, blendData2, model_psf, psf, bands)

        for sourceId in modelData.isolated:
            isolatedData1 = modelData.isolated[sourceId]
            isolatedData2 = modelData2.isolated[sourceId]
            self.assertTupleEqual(isolatedData1.origin, isolatedData2.origin)
            np.testing.assert_array_equal(
                isolatedData1.span_array,
                isolatedData2.span_array,
            )

        # Test extracting a single blend (the last parent in the model data),
        # chosen explicitly rather than relying on a leaked loop variable.
        singleId = blendIds[-1]
        modelData2 = butler.get("scarlet_model_data", dataId={}, parameters={"blend_id": singleId})
        self.assertEqual(len(modelData2.blends), 1)

        for blendId, blendData1 in modelData.blends[singleId].children.items():
            blendData2 = modelData2.blends[singleId].children[blendId]
            self._test_blend(blendData1, blendData2, model_psf, psf, bands)

        # Test extracting two blends
        modelData2 = butler.get("scarlet_model_data", dataId={}, parameters={"blend_id": [2, 3]})
        self.assertEqual(len(modelData2.blends), 2)
        for parentId in [2, 3]:
            parentData1 = modelData.blends[parentId]
            parentData2 = modelData2.blends[parentId]
            self.assertEqual(len(parentData1.children), len(parentData2.children))
            for blendId in parentData1.children:
                blendData1 = parentData1.children[blendId]
                blendData2 = parentData2.children[blendId]
                self._test_blend(blendData1, blendData2, model_psf, psf, bands)

    def test_legacy_model(self):
        repo = self._setup_butler()
        storageClass = StorageClass(
            "LsstScarletModelData",
            pytype=mes.io.LsstScarletModelData,
        )
        datasetType = DatasetType(
            "old_scarlet_model_data",
            dimensions=(),
            storageClass=storageClass,
            universe=repo.dimensions,
        )
        ref = DatasetRef(
            datasetType,
            run="test_ingestion",
            dataId={},
        )
        dataset = FileDataset(
            path=os.path.join(TESTDIR, "data", "v29_models.json"),
            formatter="lsst.daf.butler.formatters.json.JsonFormatter",
            refs=[ref],
        )

        # Ingest the legacy model into the butler
        butler = makeTestCollection(repo, uniqueId="ingestion")
        repo.registry.registerDatasetType(datasetType)
        butler.ingest(dataset)

        model = butler.get("old_scarlet_model_data", dataId={})
        self.assertEqual(len(model.blends), 2)

        test = butler.get("old_scarlet_model_data", dataId={}, parameters={"blend_id": 3495976385350991873})
        self.assertEqual(len(test.blends), 1)

    def test_older_legacy_model(self):
        repo = self._setup_butler()
        oldStorageClass = StorageClass(
            "ScarletModelData",
            pytype=lsst.scarlet.lite.io.ScarletModelData,
        )
        oldDatasetType = DatasetType(
            "old_scarlet_model_data",
            dimensions=(),
            storageClass=oldStorageClass,
            universe=repo.dimensions,
        )
        ref = DatasetRef(
            oldDatasetType,
            run="test_ingestion",
            dataId={},
        )
        dataset = FileDataset(
            path=os.path.join(TESTDIR, "data", "v29_models.json"),
            formatter="lsst.daf.butler.formatters.json.JsonFormatter",
            refs=[ref],
        )

        # Ingest the legacy model into the butler
        butler = makeTestCollection(repo, uniqueId="ingestion")
        repo.registry.registerDatasetType(oldDatasetType)
        butler.ingest(dataset)

        # Load the base repo config from the repository
        base_config = Config(os.path.join(self.repo_dir, "butler.yaml"))

        # Load the storage class override config
        override_path = os.path.join(
            os.path.dirname(lsst.daf.butler.__file__),
            "configs",
            "storageClasses.yaml"
        )
        override_config = Config(override_path)

        # Merge the configs (update base with override)
        base_config.update(override_config)

        # Create Butler with the merged config
        # The config now contains both the repo info and
        # the storage class overrides
        newButler = Butler.from_config(base_config, collections=butler.collections)

        model = newButler.get("old_scarlet_model_data", dataId={}, storageClass="LsstScarletModelData")
        self.assertEqual(len(model.blends), 2)
        self.assertEqual(len(model.isolated), 0)

    def _test_blend(self, blendData1, blendData2, model_psf, psf, bands):
        """Assert that two blend data objects describe the same blend.

        Origins and source counts must match exactly, and the rendered
        scarlet-lite models must agree to ``assert_almost_equal`` precision.
        """
        self.assertTupleEqual(blendData1.origin, blendData2.origin)
        self.assertEqual(len(blendData1.sources), len(blendData2.sources))

        # Test that the two blends are equal up to machine precision
        # once converted into scarlet lite Blend objects.
        blend1 = blendData1.minimal_data_to_blend(
            model_psf,
            psf,
            bands,
            dtype=np.float32,
        )
        blend2 = blendData2.minimal_data_to_blend(
            model_psf,
            psf,
            bands,
            dtype=np.float32,
        )
        np.testing.assert_almost_equal(blend1.get_model().data, blend2.get_model().data)

    def _setup_butler(self):
        """Create a temporary file-datastore repo and return it.

        Registers the ``scarlet_model_data`` dataset type, whose storage
        class supports the ``blend_id`` read parameter. The repo directory
        path is stored in ``self.repo_dir`` and removed at test cleanup.
        """
        repo_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        self.repo_dir = repo_dir.name
        self.addCleanup(repo_dir.cleanup)
        config = Config()
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
        repo = makeTestRepo(repo_dir.name, config=config)
        storageClass = StorageClass(
            "LsstScarletModelData",
            pytype=mes.io.LsstScarletModelData,
            parameters=('blend_id',),
            delegate="lsst.meas.extensions.scarlet.io.ScarletModelDelegate",
        )
        datasetType = DatasetType(
            "scarlet_model_data",
            dimensions=(),
            storageClass=storageClass,
            universe=repo.dimensions,
        )
        repo.registry.registerDatasetType(datasetType)
        return repo
class BadPsf(Psf):
    """PSF wrapper that is only valid at one position.

    ``computeKernelImage`` delegates to the wrapped ``psf`` when asked for
    ``validPoint`` and raises `InvalidPsfError` for any other location.
    Used to exercise the PSF fallback logic in the utils module.
    """

    def __init__(self, validPoint: Point2D, psf: GaussianPsf):
        self.validPoint = validPoint
        self.psf = psf
        super().__init__()

    def computeKernelImage(self, location: Point2D):
        # Use ``==`` (not ``!=``) so the comparison is exactly the
        # original equality test on the point type.
        if not (location == self.validPoint):
            raise InvalidPsfError(f"Invalid PSF at location {location}")
        return self.psf.computeKernelImage(location)
class TestUtils(lsst.utils.tests.TestCase):
    """Unit tests for box conversion and nearest-PSF helpers in ``mes.utils``."""

    def setUp(self):
        self.bands = tuple("gri")

    def test_box_transforms(self):
        # Test scarlet Box to BBox
        box = scl.Box((25, 30), (17, 5))
        bbox = mes.utils.scarletBoxToBBox(box)
        x0, y0 = bbox.getMin()
        width = bbox.getWidth()
        height = bbox.getHeight()
        self.assertTupleEqual((y0, x0), box.origin)
        self.assertTupleEqual((height, width), box.shape)

        # Test back to scarlet Box
        newBox = mes.utils.bboxToScarletBox(bbox)
        self.assertTupleEqual(newBox.origin, box.origin)
        self.assertTupleEqual(newBox.shape, box.shape)

    def test_computeNearestPsfGood(self):
        # Test that using a valid PSF works normally
        psf, psfImage = self._generateGoodPsf()
        coadd = self._generateCoadd(psf)

        # Test that computing the PSF works
        derivedPsf, center, dist = mes.utils.computeNearestPsf(coadd, None, "g", Point2D(25, 25))
        np.testing.assert_array_equal(derivedPsf.array, psfImage)
        self.assertEqual(center, Point2D(25, 25))
        self.assertEqual(dist, 0)

    def test_computeNearestPsfRecoverable(self):
        # Test that using a PSF not defined at the initial location
        # will fallback to a valid location.
        psf, psfImage = self._generateGoodPsf()
        coadd = self._generateCoadd(BadPsf(Point2D(1, 1), psf))
        catalog = self._generateCatalog(self.bands, [[(1, 1, 10)]])

        # Test that computing the PSF works after finding a new location.
        # Since the PSF above is *only* defined at (1, 1) it will fail to
        # compute a PSF image at (4, 5) but should fall back to (1, 1).
        derivedPsf, center, dist = mes.utils.computeNearestPsf(coadd, catalog, None, Point2D(4, 5))
        np.testing.assert_array_equal(derivedPsf.array, psfImage)
        self.assertEqual(center, Point2I(1, 1))
        self.assertEqual(dist, 5)

    def test_computeNearestPsfBad(self):
        # Test that a PSF that cannot find a matching location returns None.
        # ``_generateGoodPsf`` returns ``(psf, psfImage)``; only the PSF
        # itself must be wrapped by BadPsf.
        psf, _ = self._generateGoodPsf()
        coadd = self._generateCoadd(BadPsf(Point2D(1, 1), psf))
        catalog = self._generateCatalog(self.bands)

        # Test that computing the PSF cannot generate a PSF
        derivedPsf, center, dist = mes.utils.computeNearestPsf(coadd, catalog, None, Point2D(4, 5))
        self.assertIsNone(derivedPsf)
        self.assertIsNone(center)
        self.assertIsNone(dist)

    def test_computeNearestPsfMultiBandGood(self):
        # Test that a valid PSF in every band works normally
        bands = tuple("gri")
        psfs, psfImage = self._generateMultibandPsf([1.0, 1.2, 1.4])
        mCoadd = self._generateMultibandCoadd(psfs, bands)

        # Test that computing the PSF works
        psfArray, newCoadd = mes.utils.computeNearestPsfMultiBand(mCoadd, Point2D(25, 25), None)
        np.testing.assert_array_equal(psfArray, psfImage)
        self.assertTupleEqual(newCoadd.bands, bands)

    def test_computeNearestPsfMultiBandRecoverable(self):
        # Test that a Psf at a different location is still recoverable
        bands = tuple("gri")
        psfs, psfImage = self._generateMultibandPsf([1.0, 1.2, 1.4])
        psfs[1] = BadPsf(Point2D(1, 1), psfs[1])
        mCoadd = self._generateMultibandCoadd(psfs, bands)
        catalog = self._generateCatalog(self.bands, [[(1, 1, 10)]])

        # Test that computing the PSF works because the catalog has a peak
        # at the location of the BadPsf.
        psfArray, newCoadd = mes.utils.computeNearestPsfMultiBand(mCoadd, Point2D(25, 25), catalog)
        np.testing.assert_array_equal(psfArray, psfImage)
        self.assertTupleEqual(newCoadd.bands, bands)

    def test_computeNearestPsfMultiBandIncomplete(self):
        # Test that missing a PSF in one band returns a PSF and
        # an exposure that are missing bands.
        bands = tuple("gri")
        psfs, psfImage = self._generateMultibandPsf([1.0, 1.2, 1.4])
        psfs[1] = BadPsf(Point2D(1, 1), psfs[1])
        mCoadd = self._generateMultibandCoadd(psfs, bands)
        catalog = self._generateCatalog(self.bands)

        # Test that computing the PSF works for the g- and i-band PSFs that
        # are not BadPsf.
        psfArray, newCoadd = mes.utils.computeNearestPsfMultiBand(mCoadd, Point2D(25, 25), catalog)
        np.testing.assert_array_equal(psfArray, np.delete(psfImage, 1, axis=0))
        self.assertTupleEqual(newCoadd.bands, tuple("gi"))

    def test_computeNearestPsfMultiBandBad(self):
        # Test that None is returned if none of the PSFs can be computed
        bands = tuple("gri")
        psfs, _ = self._generateMultibandPsf([1.0, 1.2, 1.4])
        # Wrap each band's own PSF (not the whole list) so that every band
        # is invalid away from (1, 1).
        psfs = [BadPsf(Point2D(1, 1), psf) for psf in psfs]
        mCoadd = self._generateMultibandCoadd(psfs, bands)
        catalog = self._generateCatalog(self.bands)

        psfArray, newCoadd = mes.utils.computeNearestPsfMultiBand(mCoadd, Point2D(25, 25), catalog)
        self.assertIsNone(psfArray)
        self.assertIsNone(newCoadd)

    def test_buildObservationBadPsfs(self):
        # Test that creating an observation with all bad PSFs
        # raises NoWorkFound
        modelPsf = scl.utils.integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        bands = tuple("gri")
        psfs, _ = self._generateMultibandPsf([1.0, 1.2, 1.4])
        psfs = [BadPsf(Point2D(1, 1), psf) for psf in psfs]
        mCoadd = self._generateMultibandCoadd(psfs, bands)
        catalog = self._generateCatalog(self.bands)

        # Test that building the observation fails without a catalog
        with self.assertRaises(NoWorkFound):
            mes.utils.buildObservation(modelPsf, Point2I(25, 25), mCoadd)

        # Test that building the observation fails even with a catalog
        with self.assertRaises(NoWorkFound):
            mes.utils.buildObservation(modelPsf, Point2I(25, 25), mCoadd, catalog=catalog)

    def _generateGoodPsf(self, sigma: float = 1.0):
        """Return ``(psf, psfImage)`` for a 41x41 `GaussianPsf` of width ``sigma``.

        ``psfImage`` is the PSF image array evaluated at the PSF's average
        position.
        """
        psfRadius = 20
        psfShape = (2 * psfRadius + 1, 2 * psfRadius + 1)
        psf = GaussianPsf(psfShape[1], psfShape[0], sigma)
        psfImage = psf.computeImage(psf.getAveragePosition()).array
        return psf, psfImage

    def _generateMultibandPsf(self, sigmas: list[float]):
        """Return one `GaussianPsf` per entry of ``sigmas`` and their images.

        Returns
        -------
        psfs : `list` [`GaussianPsf`]
            One PSF per sigma, in the same order.
        psfImages : `numpy.ndarray`
            The PSF image arrays stacked along a new leading (band) axis.
        """
        psfs = []
        psfImages = []
        for sigma in sigmas:
            psf, psfImage = self._generateGoodPsf(sigma)
            psfs.append(psf)
            psfImages.append(psfImage)
        return psfs, np.asarray(psfImages)

    def _generateCoadd(self, psf: Psf):
        """Return an empty 50x50 float32 exposure with ``psf`` attached."""
        masked_image = afwImage.MaskedImage(Extent2I(50, 50), dtype=np.float32)
        coadd = afwImage.Exposure(masked_image, dtype=np.float32)
        coadd.setPsf(psf)
        return coadd

    def _generateMultibandCoadd(self, psfs: list[Psf], bands: tuple[str, ...] | list[str]):
        """Return an empty multiband exposure with ``psfs[i]`` in ``bands[i]``."""
        coadds = [self._generateCoadd(psf) for psf in psfs]
        return afwImage.MultibandExposure.fromExposures(bands, coadds)

    def _generateCatalog(self, bands, footprints: list[list[tuple[int, int, int]]] | None = None):
        """Generate a catalog with one source per entry of ``footprints``.

        Each entry is a list of ``addPeak`` argument tuples; every source and
        peak gets its ``merge_footprint_<band>``/``merge_peak_<band>`` flag
        set for all ``bands``. With no footprints the catalog is empty.
        """
        if footprints is None:
            footprints = []
        schema = SourceTable.makeMinimalSchema()
        peakSchema = PeakTable.makeMinimalSchema()
        for band in bands:
            schema.addField(f"merge_footprint_{band}", type="Flag")
            peakSchema.addField(f"merge_peak_{band}", type="Flag")

        table = SourceTable.make(schema)
        catalog = SourceCatalog(table)

        for peaks in footprints:
            src = catalog.addNew()
            footprint = Footprint(SpanSet(), peakSchema)
            for peak in peaks:
                footprint.addPeak(*peak)
            src.setFootprint(footprint)

            for band in bands:
                src[f"merge_footprint_{band}"] = True
                footprint.peaks[f"merge_peak_{band}"] = True
        return catalog
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Leak check supplied by ``lsst.utils.tests.MemoryTestCase``; no extra tests."""
def setup_module(module):
    """Pytest module-level hook: initialize the LSST test utilities once."""
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: set up LSST test utilities, then run unittest.
    lsst.utils.tests.init()
    unittest.main()