Coverage for tests/test_isPrimaryFlag.py: 26%

158 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-13 12:19 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import unittest 

24import numpy as np 

25import logging 

26 

27from lsst.geom import Point2I, Box2I, Extent2I 

28from lsst.skymap import TractInfo 

29from lsst.skymap.patchInfo import PatchInfo 

30import lsst.afw.image as afwImage 

31import lsst.utils.tests 

32from lsst.utils import getPackageDir 

33from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig 

34from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig 

35from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask 

36import lsst.meas.extensions.scarlet as mes 

37from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask 

38from lsst.meas.base import SingleFrameMeasurementTask 

39from lsst.pipe.tasks.setPrimaryFlags import SetPrimaryFlagsTask, getPseudoSources 

40from lsst.afw.table import SourceCatalog 

41 

42 

class NullTract(TractInfo):
    """Placeholder tract for positions outside every tract of a MockSkyMap.

    BaseSkyMap.findTract(coord) will always return a Tract,
    even if the coord isn't located in the Tract.
    In order to mimic this functionality a NullTract stands in
    for regions of the MockSkyMap that aren't contained in any
    of the tracts.
    """

    def __init__(self):
        """Create the placeholder without calling the base-class initializer.

        No attributes are set on the instance.
        """

    def getId(self):
        """Return `None`, since this placeholder has no tract identifier."""
        return None

57 

58 

class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the relevant methods from `TractInfo` needed to make
    test pass are implemented here. Since this is just for
    testing, it isn't sophisticated and requires developers to
    ensure that the size of the bounding box is evenly divisible
    by the number of patches in the Tract.

    Parameters
    ----------
    name
        Identifier of the tract; returned unchanged by `getId`.
    bbox
        Bounding box of the tract in pixel coordinates. It must provide
        ``getWidth()``, ``getHeight()`` and ``contains(point)``.
    wcs
        WCS used to convert sky coordinates to pixels in `contains`; it is
        also handed to every `PatchInfo` that is created.
    numPatches : sequence of two `int`
        Number of patches along x and along y.
    """

    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        # The patches must tile the bounding box exactly.
        assert bbox.getWidth() % numPatches[0] == 0
        assert bbox.getHeight() % numPatches[1] == 0
        self.patchWidth = bbox.getWidth()//numPatches[0]
        self.patchHeight = bbox.getHeight()//numPatches[1]

    def contains(self, coord):
        """Return whether the sky coordinate maps to a pixel inside ``bbox``."""
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the name this tract was constructed with."""
        return self.name

    def getNumPatches(self):
        """Return the ``(nx, ny)`` patch counts given at construction."""
        return self._numPatches

    def _getSequentialIndex(self, index):
        """Return the row-major sequential index of a patch.

        Parameters
        ----------
        index : sequence of two `int`
            ``(x, y)`` index of the patch within the tract.

        Returns
        -------
        sequentialIndex : `int`
            ``nx*y + x``, where ``nx`` is the number of patches along x.
        """
        xIndex, yIndex = index
        nx, _ = self._numPatches
        return nx*yIndex + xIndex

    def getPatchInfo(self, index):
        """Build the `PatchInfo` for the patch at a given ``(x, y)`` index.

        Parameters
        ----------
        index : sequence of two `int`
            ``(x, y)`` index of the patch within the tract.

        Returns
        -------
        patchInfo : `PatchInfo`
            Patch whose inner and outer bounding boxes are the same box of
            size ``(patchWidth, patchHeight)``.
        """
        xIndex, yIndex = index

        # Pixel position of the patch corner. These are kept in their own
        # variables so the patch indices are still available for the
        # sequential index below.
        # NOTE(review): the origin is measured from pixel (0, 0), not from
        # ``self.bbox``'s minimum corner -- confirm this is intended.
        xMin = xIndex*self.patchWidth
        yMin = yIndex*self.patchHeight
        bbox = Box2I(Point2I(xMin, yMin), Extent2I(self.patchWidth, self.patchHeight))

        return PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            sequentialIndex=self._getSequentialIndex(index),
            tractWcs=self.wcs,
        )

    def __getitem__(self, index):
        """Return ``getPatchInfo(index)`` so patches can be subscripted."""
        return self.getPatchInfo(index)

    def __iter__(self):
        """Yield every patch, with x varying fastest."""
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))

122 

123 

class MockSkyMap:
    """A SkyMap built from a list of pixel bounding boxes.

    Each bounding box, together with the shared WCS and patch layout,
    defines one tract. Tracts are created on demand as `MockTractInfo`
    instances whose identifier is the position of the bounding box in
    the list.

    Parameters
    ----------
    bboxes
        Bounding boxes in pixel coordinates, one per tract.
    wcs
        WCS shared by every tract.
    numPatches
        Patch counts handed to every tract.
    """

    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def __iter__(self):
        """Yield a freshly generated tract for every bounding box, in order."""
        for tractId, _ in enumerate(self.bboxes):
            yield self.generateTract(tractId)

    def __getitem__(self, index):
        """Return the tract generated for position ``index``."""
        return self.generateTract(index)

    def generateTract(self, index):
        """Create the `MockTractInfo` for the bounding box at ``index``."""
        tractBBox = self.bboxes[index]
        return MockTractInfo(index, tractBBox, self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``, else a `NullTract`."""
        candidates = (tract for tract in self if tract.contains(coord))
        found = next(candidates, None)
        if found is None:
            return NullTract()
        return found

154 

155 

class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Check the primary-source flags on single-frame and scarlet catalogs."""

    def setUp(self):
        """Load the sample exposure from disk and characterize it."""
        # Load sample input from disk
        expPath = os.path.join(getPackageDir("pipe_tasks"), "tests", "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImConfig.measureApCorr.sourceSelector["science"].doSignalToNoise = False
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

        # set log level so that warnings do not display
        logging.getLogger("lsst.calibrate").setLevel(logging.ERROR)

    def tearDown(self):
        """Drop the references created in `setUp`."""
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibConfig.doComputeSummaryStats = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & (outputCat["deblend_nChild"] > 0)), 0)

        # The model flag read in testIsScarletPrimaryFlag (which runs with
        # isSingleFrame=False) is expected to be absent from this catalog.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelSource")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the sky-object task; its footprints are added to the
        # catalog below and flagged with merge_peak_sky
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        primaryConfig = SetPrimaryFlagsTask.ConfigClass()
        setPrimaryTask = SetPrimaryFlagsTask(config=primaryConfig, schema=schema,
                                             name="setPrimaryFlags", isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add the first five sky-object footprints as flagged sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deblend
        catalog, modelData = deblendTask.run(coadds, catalog)
        # Attach footprints to the catalog
        mes.io.updateCatalogFootprints(
            modelData=modelData,
            catalog=catalog,
            band="test",
            imageForRedistribution=coadds["test"],
            removeScarletData=True,
            updateFluxColumns=True,
        )
        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of isDeblendedSource and
        # isDeblendedModelSource sources,
        # since they both have the same blended sources and only differ
        # over which model to use for the isolated sources.
        isPseudo = getPseudoSources(outputCat, primaryConfig.pseudoFilterList, schema, setPrimaryTask.log)
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)

        # Check that sky objects have not been deblended
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["deblend_nChild"] == 0)
        )

296 

297 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Test case inheriting everything from ``MemoryTestCase`` unchanged."""

300 

301 

def setup_module(module):
    """Module-level setup hook that calls ``lsst.utils.tests.init()``.

    Parameters
    ----------
    module
        The module being set up; not used here.
    """
    lsst.utils.tests.init()

304 

305 

if __name__ == "__main__":
    # Executed directly as a script: initialize the LSST test utilities,
    # then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()