Coverage for tests/test_isPrimaryFlag.py: 26%
156 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-28 03:33 -0800
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24import numpy as np
25import logging
27from lsst.geom import Point2I, Box2I, Extent2I
28from lsst.skymap import TractInfo
29from lsst.skymap.patchInfo import PatchInfo
30import lsst.afw.image as afwImage
31import lsst.utils.tests
32from lsst.utils import getPackageDir
33from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig
34from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig
35from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask
36from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
37from lsst.meas.base import SingleFrameMeasurementTask
38from lsst.pipe.tasks.setPrimaryFlags import SetPrimaryFlagsTask, getPseudoSources
39from lsst.afw.table import SourceCatalog
class NullTract(TractInfo):
    """Placeholder tract for coordinates outside every tract of a `MockSkyMap`.

    ``BaseSkyMap.findTract(coord)`` always returns some tract, even when
    ``coord`` does not actually lie inside it. To mimic that behavior,
    `MockSkyMap.findTract` returns an instance of this class when none of
    its mock tracts contains the coordinate. The ID of this tract is `None`.
    """

    def __init__(self):
        # Deliberately skip TractInfo.__init__: this placeholder carries no
        # geometry, so none of the base-class state is initialized.
        return None

    def getId(self):
        """Return `None`, since this tract does not belong to the sky map."""
        return None
class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the relevant methods from `TractInfo` needed to make
    the tests pass are implemented here. Since this is just for
    testing, it isn't sophisticated and requires developers to
    ensure that the size of the bounding box is evenly divisible
    by the number of patches in the Tract.

    Parameters
    ----------
    name : `int`
        Identifier of the tract, returned by `getId`.
    bbox : `lsst.geom.Box2I`
        Pixel bounding box of the tract.
    wcs : WCS object
        WCS of the exposure; must provide ``skyToPixel``.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y.
    """
    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        # The mock only supports tracts that split evenly into patches.
        assert bbox.getWidth() % numPatches[0] == 0
        assert bbox.getHeight() % numPatches[1] == 0
        self.patchWidth = bbox.getWidth()//numPatches[0]
        self.patchHeight = bbox.getHeight()//numPatches[1]

    def contains(self, coord):
        """Return whether ``coord`` maps to a pixel inside the tract bbox."""
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the tract identifier given as ``name``."""
        return self.name

    def getNumPatches(self):
        """Return the number of patches as ``(nx, ny)``."""
        return self._numPatches

    def getPatchInfo(self, index):
        """Build the `PatchInfo` for the patch at ``index = (x, y)``.

        The patch inner and outer bounding boxes are identical (no overlap)
        and are offset from pixel (0, 0) by the patch index times the patch
        size.
        """
        xIndex, yIndex = index
        bbox = Box2I(
            Point2I(xIndex*self.patchWidth, yIndex*self.patchHeight),
            Extent2I(self.patchWidth, self.patchHeight),
        )

        # Row-major sequential index, computed from the patch indices rather
        # than from the pixel offsets of the patch.
        nx = self._numPatches[0]
        sequentialIndex = nx*yIndex + xIndex

        return PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            sequentialIndex=sequentialIndex,
            tractWcs=self.wcs
        )

    def __getitem__(self, index):
        return self.getPatchInfo(index)

    def __iter__(self):
        # Iterate patches in row-major order (x varies fastest).
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))
class MockSkyMap:
    """A SkyMap built from a list of pixel bounding boxes.

    Each bounding box becomes one `MockTractInfo`, and all tracts share the
    same WCS and patch layout. This lets a test define its tracts directly
    in the pixel coordinates of the exposure.

    Parameters
    ----------
    bboxes : `list` [`lsst.geom.Box2I`]
        One bounding box per tract; the list position is the tract ID.
    wcs : WCS object
        WCS shared by every tract.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y in every tract.
    """
    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def __iter__(self):
        for tractIndex, _ in enumerate(self.bboxes):
            yield self.generateTract(tractIndex)

    def __getitem__(self, index):
        return self.generateTract(index)

    def generateTract(self, index):
        """Create the `MockTractInfo` for the tract at ``index``."""
        return MockTractInfo(index, self.bboxes[index], self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``.

        If no tract contains it, a `NullTract` is returned instead, which
        mirrors ``BaseSkyMap.findTract`` always returning some tract.
        """
        found = next((tract for tract in self if tract.contains(coord)), None)
        return NullTract() if found is None else found
class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Tests of the ``detect_is*`` flags set by `SetPrimaryFlagsTask`."""

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(getPackageDir("pipe_tasks"), "tests", "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

        # set log level so that warnings do not display
        logging.getLogger("lsst.calibrate").setLevel(logging.ERROR)

    def tearDown(self):
        # Drop the references to the per-test fixtures created in setUp.
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibConfig.doComputeSummaryStats = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & (outputCat["deblend_nChild"] > 0)), 0)

        # Single-frame processing must not add the coadd-only
        # deblended-model column that testIsScarletPrimaryFlag reads.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelSource")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the sky object (fake source) task
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        primaryConfig = SetPrimaryFlagsTask.ConfigClass()
        setPrimaryTask = SetPrimaryFlagsTask(config=primaryConfig, schema=schema,
                                             name="setPrimaryFlags", isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add fake sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deblend
        catalog, modelData = deblendTask.run(coadds, catalog)
        # Attach footprints to the catalog
        modelData.updateCatalogFootprints(
            catalog=catalog,
            band="test",
            psfModel=coadds["test"].getPsf(),
            redistributeImage=None,
        )
        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of detect_isDeblendedSource
        # (excluding pseudo sources) and detect_isDeblendedModelSource
        # sources, since they both have the same blended sources and only
        # differ over which model to use for the isolated sources.
        isPseudo = getPseudoSources(outputCat, primaryConfig.pseudoFilterList, schema, setPrimaryTask.log)
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)

        # Check that sky objects have not been deblended
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["deblend_nChild"] == 0)
        )
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Memory checks inherited unchanged from `lsst.utils.tests.MemoryTestCase`."""
def setup_module(module):
    """pytest module-level setup hook: initialize the LSST test utilities.

    The ``module`` argument is required by the hook signature but is unused.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When run directly as a script, initialize the LSST test utilities
    # first and then hand control to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()