Coverage for tests/test_isPrimaryFlag.py: 30%

157 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-19 05:40 -0700

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import unittest 

24import numpy as np 

25import logging 

26 

27from lsst.geom import Point2I, Box2I, Extent2I 

28from lsst.skymap import TractInfo 

29from lsst.skymap.patchInfo import PatchInfo 

30import lsst.afw.image as afwImage 

31import lsst.utils.tests 

32from lsst.utils import getPackageDir 

33from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig 

34from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig 

35from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask 

36from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask 

37from lsst.meas.base import SingleFrameMeasurementTask 

38from lsst.pipe.tasks.setPrimaryFlags import SetPrimaryFlagsTask, getPseudoSources 

39from lsst.afw.table import SourceCatalog 

40 

41 

class NullTract(TractInfo):
    """Placeholder tract returned when a coordinate lies in no mock tract.

    ``BaseSkyMap.findTract(coord)`` always hands back a Tract, even when
    ``coord`` is not actually inside it. To mimic that behavior,
    `MockSkyMap.findTract` returns a `NullTract` for any region of the
    mock sky map that is not covered by one of its tracts.
    """

    def __init__(self):
        """Create the placeholder without running ``TractInfo.__init__``."""

    def getId(self):
        """Return the tract id, which is always ``None`` for a placeholder."""
        return None

56 

57 

class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the relevant methods from `TractInfo` needed to make
    the tests pass are implemented here. Since this is just for
    testing, it isn't sophisticated and requires developers to
    ensure that the size of the bounding box is evenly divisible
    by the number of patches in the Tract.

    Parameters
    ----------
    name : `object`
        Identifier returned by `getId`.
    bbox : `lsst.geom.Box2I`
        Pixel bounding box of the tract.
    wcs : WCS object
        WCS used to map sky coordinates to pixels (must provide
        ``skyToPixel``).
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y.

    Raises
    ------
    ValueError
        Raised if the bounding box width/height is not evenly divisible by
        the corresponding entry of ``numPatches``.
    """
    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        nx, ny = numPatches[0], numPatches[1]
        # Explicit checks rather than ``assert`` so they still fire under
        # ``python -O``; otherwise the patch grid would silently be truncated.
        if bbox.getWidth() % nx != 0:
            raise ValueError(f"Tract width {bbox.getWidth()} is not divisible by {nx} patches")
        if bbox.getHeight() % ny != 0:
            raise ValueError(f"Tract height {bbox.getHeight()} is not divisible by {ny} patches")
        self.patchWidth = bbox.getWidth()//nx
        self.patchHeight = bbox.getHeight()//ny

    def contains(self, coord):
        """Return whether the sky coordinate ``coord`` falls in this tract's
        pixel bounding box.
        """
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the identifier this tract was constructed with."""
        return self.name

    def getNumPatches(self):
        """Return the number of patches along x and y."""
        return self._numPatches

    def getSequentialPatchIndexFromPair(self, index):
        """Return the row-major sequential index of patch ``index``.

        Parameters
        ----------
        index : `tuple` [`int`, `int`]
            Patch index ``(x, y)`` within the tract.

        Returns
        -------
        sequentialIndex : `int`
            ``nx*y + x`` where ``nx`` is the number of patches along x.
        """
        nx, _ = self._numPatches
        ix, iy = index
        return nx*iy + ix

    def getPatchInfo(self, index):
        """Return the `PatchInfo` for the patch at ``index``.

        Parameters
        ----------
        index : `tuple` [`int`, `int`]
            Patch index ``(x, y)`` within the tract.
        """
        ix, iy = index
        # NOTE: patch boxes are laid out from pixel (0, 0), not from the
        # tract bbox origin; inner and outer boxes are identical (no overlap).
        corner = Point2I(ix*self.patchWidth, iy*self.patchHeight)
        bbox = Box2I(corner, Extent2I(self.patchWidth, self.patchHeight))
        return PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            # Computed from the patch index, not from the pixel offsets above.
            sequentialIndex=self.getSequentialPatchIndexFromPair(index),
            tractWcs=self.wcs,
        )

    def __getitem__(self, index):
        return self.getPatchInfo(index)

    def __iter__(self):
        """Yield every patch in row-major order (x varies fastest)."""
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))

121 

122 

class MockSkyMap:
    """A SkyMap based on a list of bounding boxes.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure. This class allows us
    to define the tract(s) in the SkyMap and create them on demand.

    Parameters
    ----------
    bboxes : `list` [`lsst.geom.Box2I`]
        Pixel bounding box of each tract. A tract's position in this list
        is used as its id.
    wcs : WCS object
        WCS shared by every tract (must provide ``skyToPixel``).
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y in every tract.
    """
    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def __iter__(self):
        """Yield a freshly generated tract for every bounding box, in order."""
        # Only the index is needed; generateTract looks up the bbox itself.
        for index in range(len(self.bboxes)):
            yield self.generateTract(index)

    def __getitem__(self, index):
        return self.generateTract(index)

    def generateTract(self, index):
        """Build a new `MockTractInfo` for the tract at ``index``."""
        return MockTractInfo(index, self.bboxes[index], self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``.

        If no tract contains it, return a `NullTract`, mirroring
        ``BaseSkyMap.findTract``, which always returns a tract.
        """
        for tractInfo in self:
            if tractInfo.contains(coord):
                return tractInfo

        return NullTract()

153 

154 

class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Tests of the ``detect_isPrimary`` family of flags set by
    `SetPrimaryFlagsTask`, for single-frame calibration and for
    scarlet-deblended multiband processing.
    """

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(getPackageDir("pipe_tasks"), "tests", "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImConfig.measurePsf.psfDeterminer = "piff"
        charImConfig.measurePsf.psfDeterminer["piff"].spatialOrder = 0
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

        # set log level so that warnings do not display
        logging.getLogger("lsst.calibrate").setLevel(logging.ERROR)

    def tearDown(self):
        # Drop the references to both fixtures created in setUp.
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        isPrimary = outputCat["detect_isPrimary"]
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(isPrimary & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum(isPrimary & (outputCat["deblend_nChild"] > 0)), 0)

        # NOTE(review): "Delended" is misspelled, so this lookup can never
        # succeed and the assertion is vacuous. Confirm which field
        # single-frame processing must NOT add before correcting the name.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDelendedModelPrimary")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        # The tasks below are constructed with ``schema=schema`` and the
        # schema is also extended explicitly, so keep this order unchanged.
        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the sky-object task and the flag marking its sources
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        primaryConfig = SetPrimaryFlagsTask.ConfigClass()
        setPrimaryTask = SetPrimaryFlagsTask(config=primaryConfig, schema=schema,
                                             name="setPrimaryFlags", isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add the first five sky-object footprints as sources flagged
        # with merge_peak_sky
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deblend
        catalog, modelData = deblendTask.run(coadds, catalog)
        # Attach footprints to the catalog
        modelData.updateCatalogFootprints(
            catalog=catalog,
            band="test",
            psfModel=coadds["test"].getPsf(),
            redistributeImage=None,
        )
        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of detect_isDeblendedSource
        # (excluding pseudo-sources) and detect_isDeblendedModelSource
        # sources, since they both have the same blended sources and only
        # differ over which model to use for the isolated sources.
        isPseudo = getPseudoSources(outputCat, primaryConfig.pseudoFilterList, schema, setPrimaryTask.log)
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)

        # Check that sky objects have not been deblended
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["deblend_nChild"] == 0)
        )

293 

294 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Inherit every check from `lsst.utils.tests.MemoryTestCase` unchanged."""

297 

298 

def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Initializes the LSST test utilities; ``module`` itself is not used.
    """
    lsst.utils.tests.init()

301 

302 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities, then hand off
    # to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()