Coverage for tests / test_isPrimaryFlag.py: 25%
165 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 09:21 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 09:21 +0000
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24import numpy as np
26from lsst.geom import Point2I, Box2I, Extent2I
27from lsst.skymap import TractInfo
28from lsst.skymap.patchInfo import PatchInfo
29import lsst.afw.image as afwImage
30import lsst.utils.tests
31from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig
32from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig
33from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask, SetPrimaryFlagsTask
34import lsst.meas.extensions.scarlet as mes
35from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
36from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask
37from lsst.meas.base import SingleFrameMeasurementTask
38from lsst.afw.table import SourceCatalog
# Absolute path of the directory holding this test module; test data is
# looked up relative to it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
class NullTract(TractInfo):
    """Placeholder tract for coordinates outside every `MockSkyMap` tract.

    ``BaseSkyMap.findTract(coord)`` always hands back a Tract, even when
    ``coord`` does not actually fall inside it. To mimic that behavior,
    `MockSkyMap.findTract` returns a ``NullTract`` for regions of the mock
    sky map that are not covered by any of its tracts.
    """

    def __init__(self):
        # Deliberately skip ``TractInfo.__init__``: this placeholder carries
        # no geometry, so there is nothing to initialize.
        pass

    def getId(self):
        """Return `None`, marking this as "no tract"."""
        return None
class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the relevant methods from `TractInfo` needed to make
    the tests pass are implemented here. Since this is just for
    testing, it isn't sophisticated and requires developers to
    ensure that the size of the bounding box is evenly divisible
    by the number of patches in the Tract.

    Parameters
    ----------
    name : `object`
        Identifier returned by `getId` (the `MockSkyMap` passes the
        tract's list index).
    bbox : `lsst.geom.Box2I`
        Pixel bounding box of the tract.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS used to map sky coordinates to pixels.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y.

    Raises
    ------
    ValueError
        Raised if the bbox width/height is not evenly divisible by the
        number of patches along that axis.
    """
    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        nx, ny = numPatches
        width = bbox.getWidth()
        height = bbox.getHeight()
        # Explicit check (not ``assert``) so it still runs under ``python -O``.
        if width % nx != 0 or height % ny != 0:
            raise ValueError(
                f"Tract bbox size ({width}, {height}) is not evenly divisible "
                f"by the number of patches {tuple(numPatches)}"
            )
        self.patchWidth = width//nx
        self.patchHeight = height//ny

    def contains(self, coord):
        """Return whether ``coord`` maps to a pixel inside the tract bbox."""
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        return self.name

    def getNumPatches(self):
        return self._numPatches

    def _sequentialIndex(self, index):
        """Return the row-major sequential index of patch ``index=(x, y)``.

        This uses the patch indices, not pixel offsets, matching the
        ``nx*y + x`` ordering used by `__iter__`.
        """
        ix, iy = index
        nx, _ = self._numPatches
        return nx*iy + ix

    def getPatchInfo(self, index):
        """Build a `PatchInfo` for patch ``index=(x, y)``.

        Inner and outer bboxes are identical (no overlap region).

        NOTE(review): the patch bbox is anchored at pixel (0, 0), not at the
        tract bbox's minimum corner, so for an offset tract bbox the patches
        are not contained in the tract. Confirm this is intended.
        """
        ix, iy = index
        bbox = Box2I(
            Point2I(ix*self.patchWidth, iy*self.patchHeight),
            Extent2I(self.patchWidth, self.patchHeight),
        )
        return PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            sequentialIndex=self._sequentialIndex(index),
            tractWcs=self.wcs,
        )

    def __getitem__(self, index):
        return self.getPatchInfo(index)

    def __iter__(self):
        """Yield every patch in row-major order (x varies fastest)."""
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))
class MockSkyMap:
    """A SkyMap built from a list of pixel bounding boxes.

    Each bounding box becomes one `MockTractInfo`. All tracts share the
    same WCS and the same patch subdivision, which makes it easy to
    define tracts directly in the pixel coordinates of a test exposure.

    Parameters
    ----------
    bboxes : `list` [`lsst.geom.Box2I`]
        One pixel bounding box per tract.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS shared by every tract.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y in each tract.
    """
    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def generateTract(self, index):
        """Create the `MockTractInfo` for tract number ``index``."""
        return MockTractInfo(index, self.bboxes[index], self.wcs, self.numPatches)

    def __getitem__(self, index):
        return self.generateTract(index)

    def __iter__(self):
        # Tracts are built lazily, one per bounding box, in list order.
        for tractIndex, _ in enumerate(self.bboxes):
            yield self.generateTract(tractIndex)

    def findTract(self, coord):
        """Return the first tract containing ``coord``, else a `NullTract`."""
        match = next((tract for tract in self if tract.contains(coord)), None)
        return NullTract() if match is None else match
class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Tests of the ``detect_isPrimary`` family of flags.

    ``setUp`` loads a sample exposure from ``TESTDIR/data`` and runs
    `CharacterizeImageTask` on it. Each test then checks the flags written by
    `SetPrimaryFlagsTask` for a different processing path: single-frame
    calibration, and scarlet deblending on a mock tract/patch layout.
    """

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(TESTDIR, "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImConfig.measureApCorr.sourceSelector["science"].doSignalToNoise = False
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

    def tearDown(self):
        # Drop the references created in setUp so the large objects can be
        # freed between tests.
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibConfig.doComputeSummaryStats = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        isPrimary = outputCat["detect_isPrimary"]
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(isPrimary & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum(isPrimary & (outputCat["deblend_nChild"] > 0)), 0)

        # The deblended-model primary column is not expected in a
        # single-frame catalog.
        # NOTE(review): confirm that SetPrimaryFlagsTask can ever add a column
        # with this exact name; if not, this check passes vacuously.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelPrimary")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender.

        Runs detection, sky-source injection, deconvolution, scarlet
        deblending, measurement and `SetPrimaryFlagsTask` on the sample
        exposure, using a `MockSkyMap` tract split into 3x3 patches.
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the fake source injection task
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deconvolution task
        deconvolveConfig = DeconvolveExposureTask.ConfigClass()
        deconvolveTask = DeconvolveExposureTask(config=deconvolveConfig)

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        scarletConfig.processSingles = True
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        # As of DM-51670 we also include `base_PsfFlux` to ensure that
        # the measurement plugins run correctly with the split between
        # parent and child catalogs.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord", "base_PsfFlux"]
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        setPrimaryTask = SetPrimaryFlagsTask(schema=schema, isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add fake sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deconvolve the images
        deconvolved = deconvolveTask.run(coadds["test"], catalog).deconvolved
        mDeconvolved = afwImage.MultibandExposure.fromExposures(["test"], [deconvolved])
        # deblend
        # This is a hack because the variance is not calibrated properly
        # (it is 3 orders of magnitude too high), which causes the deblender
        # to improperly deblend most sources due to the sparsity constraint.
        coadds.variance.array[:] = 2e-1
        mDeconvolved.variance.array[:] = 2e-1
        result = deblendTask.run(coadds, mDeconvolved, catalog)
        modelData = result.scarletModelData
        catalog = result.deblendedCatalog
        # Attach footprints to the catalog
        mes.io.updateCatalogFootprints(
            modelData=modelData,
            catalog=catalog,
            band="test",
            imageForRedistribution=coadds["test"],
            removeScarletData=True,
            updateFluxColumns=True,
        )

        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        isPseudo = outputCat["merge_peak_sky"]

        # Check that all 5 pseudo-sources were created
        self.assertEqual(np.sum(isPseudo), 5)

        # There should be the same number of deblendedSource and
        # deblendedModelSource sources (excluding pseudo-sources),
        # since they both have the same blended sources and only differ
        # over which model to use for the isolated sources.
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & isPseudo), 0)

        # Check that sky objects have not been deblended
        # (deblended sources have parent > 0)
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["parent"] == 0)
        )

        # Check that PSF-flux measurements were performed on every source in
        # the output catalog; separate assertions report which check failed.
        psfFlux = outputCat["base_PsfFlux_instFlux"]
        self.assertTrue(np.all(psfFlux != 0))
        self.assertTrue(np.all(np.isfinite(psfFlux)))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Module-level memory test case.

    All behavior is inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """
def setup_module(module):
    """Initialize the LSST test utilities once for this module.

    Named per pytest's xunit-style module setup hook, which pytest calls
    before running the module's tests.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up (unused).
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When run as a script, initialize the LSST test utilities first, then
    # hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()