Coverage for tests/test_isPrimaryFlag.py: 30%
157 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-29 03:24 -0700
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-29 03:24 -0700
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24import numpy as np
25import logging
27from lsst.geom import Point2I, Box2I, Extent2I
28from lsst.skymap import TractInfo
29from lsst.skymap.patchInfo import PatchInfo
30import lsst.afw.image as afwImage
31import lsst.utils.tests
32from lsst.utils import getPackageDir
33from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig
34from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig
35from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask
36from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
37from lsst.meas.base import SingleFrameMeasurementTask
38from lsst.pipe.tasks.setPrimaryFlags import SetPrimaryFlagsTask, getPseudoSources
39from lsst.afw.table import SourceCatalog
class NullTract(TractInfo):
    """Placeholder tract for coordinates outside every tract of a MockSkyMap.

    ``BaseSkyMap.findTract(coord)`` always hands back some tract, even
    when ``coord`` does not actually lie inside it. To mimic that,
    ``MockSkyMap.findTract`` returns an instance of this class for any
    region not covered by one of its tracts.
    """

    def __init__(self):
        """Build an empty placeholder.

        ``TractInfo.__init__`` is intentionally not called: this object
        carries no geometry, only the ``None`` id below.
        """

    def getId(self):
        """Return the id of this placeholder tract, which is always `None`."""
        tractId = None
        return tractId
class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the relevant methods from `TractInfo` needed to make the
    tests pass are implemented here. Since this is just for
    testing, it isn't sophisticated and requires developers to
    ensure that the size of the bounding box is evenly divisible
    by the number of patches in the Tract.

    Parameters
    ----------
    name : `object`
        Identifier returned by `getId`.
    bbox : `lsst.geom.Box2I`
        Bounding box of the tract, in the pixel coordinates of ``wcs``.
    wcs : WCS object
        WCS of the exposure (must provide ``skyToPixel``); it is also
        passed to every `PatchInfo` as ``tractWcs``.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y. Each must evenly divide the
        corresponding bbox dimension.
    """
    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        assert bbox.getWidth() % numPatches[0] == 0
        assert bbox.getHeight() % numPatches[1] == 0
        self.patchWidth = bbox.getWidth()//numPatches[0]
        self.patchHeight = bbox.getHeight()//numPatches[1]

    def contains(self, coord):
        """Return whether the sky coordinate ``coord`` maps to a pixel
        inside the tract bounding box.
        """
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the tract identifier given at construction."""
        return self.name

    def getNumPatches(self):
        """Return the ``(nx, ny)`` number of patches in the tract."""
        return self._numPatches

    def getPatchInfo(self, index):
        """Build the `PatchInfo` for the patch at ``index``.

        Parameters
        ----------
        index : `tuple` [`int`, `int`]
            ``(x, y)`` patch index within the tract.

        Returns
        -------
        patchInfo : `lsst.skymap.patchInfo.PatchInfo`
            Patch whose inner and outer bboxes are identical (no overlap
            region) and which tile the tract bounding box.
        """
        xIndex, yIndex = index
        width = self.patchWidth
        height = self.patchHeight

        # Patches tile the tract bbox, so offset them by the tract origin
        # rather than starting at pixel (0, 0).
        x0 = self.bbox.getMinX() + xIndex*width
        y0 = self.bbox.getMinY() + yIndex*height
        bbox = Box2I(Point2I(x0, y0), Extent2I(width, height))

        # Row-major sequential index, computed from the patch indices
        # (not from the pixel origin of the patch).
        nx, ny = self._numPatches
        sequentialIndex = nx*yIndex + xIndex

        patchInfo = PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            sequentialIndex=sequentialIndex,
            tractWcs=self.wcs
        )
        return patchInfo

    def __getitem__(self, index):
        """Allow ``tract[x, y]`` as a shortcut for `getPatchInfo`."""
        return self.getPatchInfo(index)

    def __iter__(self):
        """Yield every patch, row by row (x varies fastest)."""
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))
class MockSkyMap:
    """A SkyMap based on a list of bounding boxes.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure. This class allows us
    to define the tract(s) in the SkyMap and create
    them.

    Parameters
    ----------
    bboxes : `list` [`lsst.geom.Box2I`]
        Pixel bounding box of each tract. A tract's id is its position in
        this list.
    wcs : WCS object
        WCS shared by every tract.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y in every tract.
    """
    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def __iter__(self):
        """Yield a newly generated `MockTractInfo` for each bounding box."""
        for index, _ in enumerate(self.bboxes):
            yield self.generateTract(index)

    def __getitem__(self, index):
        """Return the tract at position ``index``."""
        return self.generateTract(index)

    def generateTract(self, index):
        """Create the `MockTractInfo` for ``self.bboxes[index]``.

        A new object is built on every call; tracts are not cached.
        """
        return MockTractInfo(index, self.bboxes[index], self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``.

        Like ``BaseSkyMap.findTract``, this always returns a tract. When no
        tract contains ``coord`` a `NullTract` (whose id is `None`) is
        returned.
        """
        for tractInfo in self:
            if tractInfo.contains(coord):
                return tractInfo

        return NullTract()
class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Tests of the ``detect_is*`` flags written by `SetPrimaryFlagsTask`."""

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(getPackageDir("pipe_tasks"), "tests", "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImConfig.measurePsf.psfDeterminer = "piff"
        charImConfig.measurePsf.psfDeterminer["piff"].spatialOrder = 0
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

        # Set log level so that warnings do not display. Remember the
        # previous level so tearDown can restore this process-global logger
        # and not silence it for unrelated tests.
        self._calibrateLog = logging.getLogger("lsst.calibrate")
        self._calibrateLogLevel = self._calibrateLog.level
        self._calibrateLog.setLevel(logging.ERROR)

    def tearDown(self):
        self._calibrateLog.setLevel(self._calibrateLogLevel)
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & (outputCat["deblend_nChild"] > 0)), 0)

        # The deblended-model column is expected only when scarlet is the
        # deblender (see testIsScarletPrimaryFlag), so it must be absent here.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelSource")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the fake source injection task
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        primaryConfig = SetPrimaryFlagsTask.ConfigClass()
        setPrimaryTask = SetPrimaryFlagsTask(config=primaryConfig, schema=schema,
                                             name="setPrimaryFlags", isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add fake sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deblend
        catalog, modelData = deblendTask.run(coadds, catalog)
        # Attach footprints to the catalog
        modelData.updateCatalogFootprints(
            catalog=catalog,
            band="test",
            psfModel=coadds["test"].getPsf(),
            redistributeImage=None,
        )
        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of detect_isDeblendedSource
        # (excluding pseudo sources) and detect_isDeblendedModelSource
        # sources, since they both have the same blended sources and only
        # differ over which model to use for the isolated sources.
        isPseudo = getPseudoSources(outputCat, primaryConfig.pseudoFilterList, schema, setPrimaryTask.log)
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)

        # Check that sky objects have not been deblended
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["deblend_nChild"] == 0)
        )
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook (pytest convention).

    Calls ``lsst.utils.tests.init()`` once before this module's tests run.
    The ``module`` argument is required by the hook signature but unused.
    """
    lsst.utils.tests.init()
# Allow running this file directly as a script rather than through pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()