Coverage for tests/test_isPrimaryFlag.py: 26%
158 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-06 12:26 +0000
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24import numpy as np
25import logging
27from lsst.geom import Point2I, Box2I, Extent2I
28from lsst.skymap import TractInfo
29from lsst.skymap.patchInfo import PatchInfo
30import lsst.afw.image as afwImage
31import lsst.utils.tests
32from lsst.utils import getPackageDir
33from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig
34from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig
35from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask
36import lsst.meas.extensions.scarlet as mes
37from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
38from lsst.meas.base import SingleFrameMeasurementTask
39from lsst.pipe.tasks.setPrimaryFlags import SetPrimaryFlagsTask, getPseudoSources
40from lsst.afw.table import SourceCatalog
class NullTract(TractInfo):
    """Placeholder tract for sky positions outside every mock tract.

    BaseSkyMap.findTract(coord) always hands back a Tract, even when
    the coordinate does not lie inside it. To mimic that behaviour,
    MockSkyMap returns a NullTract for any region that none of its
    tracts contains.
    """

    def __init__(self):
        # The TractInfo constructor is intentionally not invoked: this
        # object carries no state and only needs to answer getId().
        pass

    def getId(self):
        """Return `None`, the identifier used for "no tract"."""
        return None
class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the relevant methods from `TractInfo` needed to make
    the tests pass are implemented here. Since this is just for
    testing, it isn't sophisticated and requires developers to
    ensure that the size of the bounding box is evenly divisible
    by the number of patches in the Tract.

    Parameters
    ----------
    name
        Identifier reported by `getId`.
    bbox
        Bounding box of the tract in pixel coordinates. Only its
        ``getWidth``, ``getHeight`` and ``contains`` methods are used.
    wcs
        WCS used to map sky coordinates to pixels via ``skyToPixel``;
        also passed on to every patch as ``tractWcs``.
    numPatches
        Pair ``(nx, ny)`` giving the number of patches along x and y.
    """

    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        # The patch grid must tile the bounding box exactly.
        assert bbox.getWidth() % numPatches[0] == 0
        assert bbox.getHeight() % numPatches[1] == 0
        self.patchWidth = bbox.getWidth()//numPatches[0]
        self.patchHeight = bbox.getHeight()//numPatches[1]

    def contains(self, coord):
        """Return whether sky position ``coord`` maps to a pixel inside
        the tract bounding box.
        """
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the name this tract was constructed with."""
        return self.name

    def getNumPatches(self):
        """Return the ``(nx, ny)`` patch counts given at construction."""
        return self._numPatches

    def getSequentialPatchIndexFromPair(self, index):
        """Return the row-major sequential index of a patch.

        Parameters
        ----------
        index
            Pair ``(x, y)`` of patch indices (not pixel offsets).

        Returns
        -------
        sequentialIndex : `int`
            ``nx*y + x`` where ``nx`` is the number of patches along x.
        """
        xIndex, yIndex = index
        nx, _ = self._numPatches
        return nx*yIndex + xIndex

    def getPatchInfo(self, index):
        """Build the `PatchInfo` for the patch at grid position ``index``.

        The inner and outer bounding boxes are the same box, so the
        patches have no overlap region.
        """
        xIndex, yIndex = index

        # NOTE(review): the patch origin is measured from pixel (0, 0),
        # not from the minimum corner of the tract bbox -- confirm that
        # this is intended before relying on patches other than (0, 0).
        corner = Point2I(xIndex*self.patchWidth, yIndex*self.patchHeight)
        bbox = Box2I(corner, Extent2I(self.patchWidth, self.patchHeight))

        # The sequential index must come from the patch indices; it was
        # previously computed from the pixel offsets of the patch corner.
        return PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            sequentialIndex=self.getSequentialPatchIndexFromPair(index),
            tractWcs=self.wcs,
        )

    def __getitem__(self, index):
        return self.getPatchInfo(index)

    def __iter__(self):
        # Yield patches in row-major order (x varies fastest).
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))
class MockSkyMap:
    """A minimal SkyMap assembled from a list of tract bounding boxes.

    Each bounding box (in pixel coordinates) defines one tract; all
    tracts share the same WCS and the same patch layout. Tracts are
    not stored: a fresh `MockTractInfo` is built whenever one is
    requested.
    """

    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes, self.wcs, self.numPatches = bboxes, wcs, numPatches

    def __iter__(self):
        # One tract per stored bounding box, created on demand.
        for tractIndex, _ in enumerate(self.bboxes):
            yield self.generateTract(tractIndex)

    def __getitem__(self, index):
        return self.generateTract(index)

    def generateTract(self, index):
        """Create the tract for ``self.bboxes[index]``; ``index`` also
        serves as the tract's identifier.
        """
        tractBBox = self.bboxes[index]
        return MockTractInfo(index, tractBBox, self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``, or a `NullTract`
        when no tract contains it.
        """
        containing = (tract for tract in self if tract.contains(coord))
        match = next(containing, None)
        if match is None:
            return NullTract()
        return match
class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Check the primary-source flags produced by `CalibrateTask` and by
    `SetPrimaryFlagsTask` run on a scarlet-deblended catalog.
    """

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(getPackageDir("pipe_tasks"), "tests", "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImConfig.measureApCorr.sourceSelector["science"].doSignalToNoise = False
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

        # set log level so that warnings do not display
        logging.getLogger("lsst.calibrate").setLevel(logging.ERROR)

    def tearDown(self):
        # Drop both fixtures created in setUp. The second statement was
        # previously a bare expression that released nothing.
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibConfig.doComputeSummaryStats = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & (outputCat["deblend_nChild"] > 0)), 0)

        # The model-source flag must be absent from the single-frame
        # catalog. The field name was previously misspelled, which made
        # this check pass regardless of the schema contents.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelSource")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the fake source injection task
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        primaryConfig = SetPrimaryFlagsTask.ConfigClass()
        setPrimaryTask = SetPrimaryFlagsTask(config=primaryConfig, schema=schema,
                                             name="setPrimaryFlags", isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add fake sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deblend
        catalog, modelData = deblendTask.run(coadds, catalog)
        # Attach footprints to the catalog
        mes.io.updateCatalogFootprints(
            modelData=modelData,
            catalog=catalog,
            band="test",
            imageForRedistribution=coadds["test"],
            removeScarletData=True,
            updateFluxColumns=True,
        )
        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of isDeblendedSource and
        # isDeblendedModelSource sources once pseudo sources are excluded,
        # since they both have the same blended sources and only differ
        # over which model to use for the isolated sources.
        isPseudo = getPseudoSources(outputCat, primaryConfig.pseudoFilterList, schema, setPrimaryTask.log)
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)

        # Check that sky objects have not been deblended
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["deblend_nChild"] == 0)
        )
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`
    without any customisation.
    """
def setup_module(module):
    """Initialise the LSST test utilities; ``module`` itself is unused."""
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Executed directly as a script: initialise the LSST test utilities,
    # then let unittest discover and run the test cases in this module.
    lsst.utils.tests.init()
    unittest.main()