Coverage for tests/test_isPrimaryFlag.py: 30%
155 statements
Generated by coverage.py v6.4.4, created at 2022-09-12 01:27 -0700
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24import numpy as np
25import logging
27from lsst.geom import Point2I, Box2I, Extent2I
28from lsst.skymap import TractInfo
29from lsst.skymap.patchInfo import PatchInfo
30import lsst.afw.image as afwImage
31import lsst.utils.tests
32from lsst.utils import getPackageDir
33from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig
34from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig
35from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask
36from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
37from lsst.meas.base import SingleFrameMeasurementTask
38from lsst.pipe.tasks.setPrimaryFlags import SetPrimaryFlagsTask, getPseudoSources
39from lsst.afw.table import SourceCatalog
class NullTract(TractInfo):
    """Placeholder tract for sky positions outside every tract of a MockSkyMap.

    ``BaseSkyMap.findTract(coord)`` always returns a tract, even when
    ``coord`` does not actually lie inside it. To mimic that behavior,
    ``MockSkyMap.findTract`` returns an instance of this class whenever
    none of its mock tracts contains the coordinate.
    """

    def __init__(self):
        # Intentionally do not call ``TractInfo.__init__``: this placeholder
        # has no geometry, so none of the base-class state is initialized.
        return

    def getId(self):
        """Return the tract identifier; always `None` for this placeholder."""
        return None
class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define a Tract in
    terms of its bounding box in pixel coordinates along with a WCS for
    the exposure.

    Only the methods of `TractInfo` needed to make the tests pass are
    implemented here. Since this is just for testing it isn't
    sophisticated: the size of the bounding box must be evenly divisible
    by the number of patches along each axis.

    Parameters
    ----------
    name : `object`
        Identifier returned by `getId` (the tract index when the tract is
        created by `MockSkyMap`).
    bbox : `lsst.geom.Box2I`
        Bounding box of the tract, in pixel coordinates of ``wcs``.
    wcs : WCS object
        WCS providing ``skyToPixel``; shared by the tract and its patches.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y.

    Raises
    ------
    ValueError
        Raised if the bbox width (height) is not divisible by the number
        of patches along x (y).
    """
    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        nx, ny = numPatches[0], numPatches[1]
        # Explicit checks (rather than ``assert``) so they still run under ``python -O``.
        if bbox.getWidth() % nx != 0 or bbox.getHeight() % ny != 0:
            raise ValueError(
                f"Tract bbox {bbox} cannot be evenly divided into {numPatches} patches"
            )
        self.patchWidth = bbox.getWidth()//nx
        self.patchHeight = bbox.getHeight()//ny

    def contains(self, coord):
        """Return whether the sky position ``coord`` falls inside the tract bbox."""
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the identifier given at construction."""
        return self.name

    def getNumPatches(self):
        """Return the number of patches along (x, y)."""
        return self._numPatches

    def getPatchInfo(self, index):
        """Build the `PatchInfo` for the patch at ``index``.

        Parameters
        ----------
        index : `tuple` [`int`, `int`]
            Patch index (x, y) within the tract.

        Returns
        -------
        patchInfo : `lsst.skymap.patchInfo.PatchInfo`
            Patch whose inner and outer bboxes are identical.
        """
        indexX, indexY = index
        width = self.patchWidth
        height = self.patchHeight

        # NOTE(review): patch origins are measured from pixel (0, 0), not
        # from the tract bbox minimum, so for a tract bbox that does not
        # start at (0, 0) the patches do not exactly tile the tract.
        bbox = Box2I(Point2I(indexX*width, indexY*height), Extent2I(width, height))

        # Row-major sequential index computed from the patch *indices*;
        # the pixel offsets above must not be used here.
        nx = self._numPatches[0]
        sequentialIndex = nx*indexY + indexX

        patchInfo = PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            sequentialIndex=sequentialIndex,
            tractWcs=self.wcs
        )
        return patchInfo

    def __getitem__(self, index):
        return self.getPatchInfo(index)

    def __iter__(self):
        # Iterate patches row by row: x varies fastest.
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))
class MockSkyMap:
    """Minimal sky map whose tracts are defined by pixel bounding boxes.

    All tracts share one WCS and one patch layout. Tract ``i`` is built on
    demand from ``bboxes[i]``, so tests can place tract boundaries at exact
    pixel positions on a single exposure.

    Parameters
    ----------
    bboxes : sequence of `lsst.geom.Box2I`
        One bounding box per tract, in pixel coordinates of ``wcs``.
    wcs : WCS object
        WCS shared by every tract.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y in every tract.
    """
    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def __iter__(self):
        # A fresh MockTractInfo is created for each bounding box, in order.
        for tractIndex, _ in enumerate(self.bboxes):
            yield self.generateTract(tractIndex)

    def __getitem__(self, index):
        return self.generateTract(index)

    def generateTract(self, index):
        """Return a new `MockTractInfo` built from ``bboxes[index]``."""
        tractBBox = self.bboxes[index]
        return MockTractInfo(index, tractBBox, self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``.

        If no tract contains it, return a `NullTract`, mirroring a real sky
        map that always hands back some tract.
        """
        containing = (tract for tract in self if tract.contains(coord))
        match = next(containing, None)
        return NullTract() if match is None else match
class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Tests for the ``detect_is*`` flags set by `SetPrimaryFlagsTask`.

    Covers single-frame processing (via `CalibrateTask`) and coadd
    processing with the scarlet deblender.
    """

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(getPackageDir("pipe_tasks"), "tests", "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

        # set log level so that warnings do not display
        logging.getLogger("lsst.calibrate").setLevel(logging.ERROR)

    def tearDown(self):
        # Drop the per-test fixtures. Both attributes must be ``del``-ed;
        # merely evaluating ``self.charImResults`` would release nothing.
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.

        Also checks that the ``detect_isDeblendedModelSource`` column is not
        present in the single-frame output.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & (outputCat["deblend_nChild"] > 0)), 0)

        # Use the real column name (the one checked in the scarlet test below);
        # a misspelled name would make this assertion pass vacuously.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelSource")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender.

        Also checks the tract/patch inner flags against the mock sky map
        geometry, and that sky sources are neither primary nor deblended.
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the fake source injection task
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        primaryConfig = SetPrimaryFlagsTask.ConfigClass()
        setPrimaryTask = SetPrimaryFlagsTask(config=primaryConfig, schema=schema,
                                             name="setPrimaryFlags", isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add fake sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deblend
        catalog, modelData = deblendTask.run(coadds, catalog)
        # Attach footprints to the catalog
        modelData.updateCatalogFootprints(
            catalog=catalog,
            band="test",
            psfModel=coadds["test"].getPsf(),
            redistributeImage=None,
        )
        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of detect_isDeblendedSource and
        # detect_isDeblendedModelSource sources (excluding pseudo-sources),
        # since they both have the same blended sources and only differ
        # over which model to use for the isolated sources.
        isPseudo = getPseudoSources(outputCat, primaryConfig.pseudoFilterList, schema, setPrimaryTask.log)
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)

        # Check that sky objects have not been deblended
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["deblend_nChild"] == 0)
        )
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Inherit the memory checks of ``lsst.utils.tests.MemoryTestCase``.

    No additional test methods are defined here.
    """
def setup_module(module):
    """Module-level setup (pytest ``setup_module`` convention).

    Initializes the ``lsst.utils.tests`` machinery before any test runs.
    ``module`` is accepted for the hook signature but not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Support running this file directly as a script, in addition to pytest.
    lsst.utils.tests.init()
    unittest.main()