Coverage for tests/test_isPrimaryFlag.py : 27%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24import numpy as np
26from lsst.geom import Point2I, Box2I, Extent2I
27from lsst.skymap import TractInfo
28from lsst.skymap.patchInfo import PatchInfo
29import lsst.afw.image as afwImage
30import lsst.utils.tests
31from lsst.utils import getPackageDir
32from lsst.log import Log
33from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig
34from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig
35from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask
36from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
37from lsst.meas.base import SingleFrameMeasurementTask
38from lsst.pipe.tasks.setPrimaryFlags import SetPrimaryFlagsTask
39from lsst.afw.table import SourceCatalog
class NullTract(TractInfo):
    """Placeholder tract for coordinates outside every MockSkyMap tract.

    `BaseSkyMap.findTract` always returns some tract, even when the
    coordinate does not actually lie inside it. `MockSkyMap.findTract`
    mimics that by handing back a `NullTract` whenever none of its mock
    tracts contains the coordinate. Its id is `None`, so it never compares
    equal to the integer ids of the mock tracts.
    """
    def __init__(self):
        # Intentionally do not call TractInfo.__init__: this placeholder has
        # no geometry and only needs to answer getId().
        return None

    def getId(self):
        """Return `None`, the id meaning "not in any tract"."""
        return None
class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the methods of `TractInfo` needed by these tests are
    implemented here. Since this is just for testing, it is not
    sophisticated and requires that both dimensions of the bounding
    box be evenly divisible by the number of patches along each axis.

    Parameters
    ----------
    name : `object`
        Identifier returned by `getId`.
    bbox : `lsst.geom.Box2I`
        Pixel bounding box of the tract.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS used by `contains` to map sky coordinates to pixels.
    numPatches : `int`
        Number of patches along each axis; the tract is divided into
        ``numPatches`` x ``numPatches`` equally sized patches.

    Raises
    ------
    ValueError
        Raised if the bbox width or height is not divisible by
        ``numPatches``.
    """
    def __init__(self, name, bbox, wcs, numPatches):
        width = bbox.getWidth()
        height = bbox.getHeight()
        if width % numPatches or height % numPatches:
            raise ValueError(
                f"Tract bbox dimensions ({width}, {height}) must be divisible "
                f"by numPatches={numPatches}"
            )
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        # getNumPatches reads this attribute; the original stored only the
        # misspelled ``numPatchs``, so getNumPatches raised AttributeError.
        self._numPatches = numPatches
        # Misspelled alias kept for backward compatibility.
        self.numPatchs = numPatches
        self.patchWidth = width//numPatches
        self.patchHeight = height//numPatches

    def contains(self, coord):
        """Return whether ``coord`` maps (via the WCS) inside the tract bbox."""
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the tract identifier given at construction."""
        return self.name

    def getNumPatches(self):
        """Return the number of patches along (x, y).

        Returned as a 2-tuple so callers (e.g. `__iter__`) can unpack it
        as ``xNum, yNum``. The patch grid is square, so both entries are
        equal.
        """
        return (self._numPatches, self._numPatches)

    def getPatchInfo(self, index):
        """Build the `PatchInfo` for patch ``index`` = ``(x, y)``.

        Inner and outer bounding boxes are identical (no overlap region).
        """
        x, y = index
        width = self.patchWidth
        height = self.patchHeight

        # NOTE(review): patch boxes are laid out from pixel (0, 0), not from
        # ``self.bbox.getMin()``, so for a tract bbox that does not start at
        # the origin the patches are offset from the tract. Confirm intent.
        x = x*width
        y = y*height

        bbox = Box2I(Point2I(x, y), Extent2I(width, height))

        return PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
        )

    def __getitem__(self, index):
        return self.getPatchInfo(index)

    def __iter__(self):
        """Yield every patch in row-major order (y outer, x inner)."""
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))
class MockSkyMap:
    """A minimal stand-in for a SkyMap built from a list of tract bboxes.

    Each bounding box in ``bboxes`` becomes one tract whose id is its list
    index. All tracts share the same ``wcs`` and ``numPatches``. Tract
    objects are rebuilt on every access rather than cached.

    Parameters
    ----------
    bboxes : sequence of `lsst.geom.Box2I`
        Pixel bounding box of each tract, indexed by tract id.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS shared by every tract.
    numPatches : `int`
        Number of patches along each axis of every tract.
    """
    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def __iter__(self):
        # Only the index is needed; generateTract looks the bbox up itself.
        for tractIndex, _ in enumerate(self.bboxes):
            yield self.generateTract(tractIndex)

    def __getitem__(self, index):
        return self.generateTract(index)

    def generateTract(self, index):
        """Construct the mock tract whose bbox is ``self.bboxes[index]``."""
        tractBBox = self.bboxes[index]
        return MockTractInfo(index, tractBBox, self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``, else a `NullTract`."""
        found = next((tract for tract in self if tract.contains(coord)), None)
        if found is None:
            return NullTract()
        return found
class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Tests for the ``detect_is*`` flags written by `SetPrimaryFlagsTask`."""

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(getPackageDir("pipe_tasks"), "tests", "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

        # set log level so that warnings do not display
        Log.getLogger("calibrate").setLevel(Log.ERROR)

    def tearDown(self):
        # Release both per-test fixtures. The original only evaluated
        # ``self.charImResults`` (a no-op) instead of deleting it.
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        isPrimary = outputCat["detect_isPrimary"]
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.count_nonzero(isPrimary & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.count_nonzero(isPrimary & (outputCat["deblend_nChild"] > 0)), 0)

        # NOTE(review): "detect_isDelendedModelPrimary" is misspelled and does
        # not match the coadd column "detect_isDeblendedModelSource" checked
        # in testIsScarletPrimaryFlag, so this KeyError check is presumably
        # vacuous. Confirm which column single-frame processing should lack
        # before renaming it.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDelendedModelPrimary")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, 3)
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the fake source injection task
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        primaryConfig = SetPrimaryFlagsTask.ConfigClass()
        setPrimaryTask = SetPrimaryFlagsTask(config=primaryConfig, schema=schema,
                                             name="setPrimaryFlags", isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add fake sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deblend
        result = deblendTask.run(coadds, catalog)
        # measure
        measureTask.run(result["test"], self.exposure)
        outputCat = result["test"]
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of detect_isDeblendedSource and
        # detect_isDeblendedModelSource sources,
        # since they both have the same blended sources and only differ
        # over which model to use for the isolated sources.
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"]),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(
            np.count_nonzero(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run `lsst.utils.tests.MemoryTestCase`'s checks for this test module.

    Adds nothing of its own; subclassing is what enables the inherited tests.
    """
def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    ``module`` is the module being set up; it is not used here.
    """
    lsst.utils.tests.init()
# Allow running this test file directly as a script; the coverage-report
# annotation that had been fused onto the guard line made it a syntax error.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()