Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import unittest 

24import numpy as np 

25import logging 

26 

27from lsst.geom import Point2I, Box2I, Extent2I 

28from lsst.skymap import TractInfo 

29from lsst.skymap.patchInfo import PatchInfo 

30import lsst.afw.image as afwImage 

31import lsst.utils.tests 

32from lsst.utils import getPackageDir 

33from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig 

34from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig 

35from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask 

36from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask 

37from lsst.meas.base import SingleFrameMeasurementTask 

38from lsst.pipe.tasks.setPrimaryFlags import SetPrimaryFlagsTask, getPseudoSources 

39from lsst.afw.table import SourceCatalog 

40 

41 

class NullTract(TractInfo):
    """Placeholder tract for positions outside every tract of a MockSkyMap.

    ``BaseSkyMap.findTract(coord)`` always returns a Tract, even when the
    coord is not located inside that Tract. To mimic this behavior,
    a ``NullTract`` stands in for the regions of the MockSkyMap that are
    not contained in any of its tracts.
    """

    def __init__(self):
        """Create the placeholder without running the base initializer."""

    def getId(self):
        """Return `None`, since a null tract has no identifier."""
        return None

56 

57 

class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the relevant methods from `TractInfo` needed to make
    the tests pass are implemented here. Since this is just for
    testing, it isn't sophisticated and requires developers to
    ensure that the size of the bounding box is evenly divisible
    by the number of patches in the Tract.

    Parameters
    ----------
    name
        Identifier of the tract; returned unchanged by `getId`.
    bbox
        Bounding box of the tract in pixel coordinates. It must provide
        ``getWidth()``, ``getHeight()`` and ``contains(point)``.
    wcs
        WCS used to convert sky coordinates to pixels; it must provide
        ``skyToPixel(coord)``.
    numPatches : `int`
        Number of patches along each axis of the tract. Both the width
        and the height of ``bbox`` must be divisible by this number.
    """

    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        # ``getNumPatches`` reads ``_numPatches`` and ``__iter__`` unpacks it
        # as an (x, y) pair, so store the per-axis count for both axes.
        self._numPatches = (numPatches, numPatches)
        # Historical (misspelled) attribute, kept so that any existing code
        # reading it continues to work.
        self.numPatchs = numPatches
        assert bbox.getWidth() % numPatches == 0, \
            "bbox width must be divisible by numPatches"
        assert bbox.getHeight() % numPatches == 0, \
            "bbox height must be divisible by numPatches"
        self.patchWidth = bbox.getWidth()//numPatches
        self.patchHeight = bbox.getHeight()//numPatches

    def contains(self, coord):
        """Check whether a sky coordinate falls inside the tract bbox.

        Parameters
        ----------
        coord
            Sky coordinate, converted to pixels with ``wcs.skyToPixel``.

        Returns
        -------
        Result of ``bbox.contains`` for the integer pixel position.
        """
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the name given to the tract at construction."""
        return self.name

    def getNumPatches(self):
        """Return the number of patches along each axis as ``(nx, ny)``."""
        return self._numPatches

    def getPatchInfo(self, index):
        """Build the `PatchInfo` for the patch at a given grid position.

        Parameters
        ----------
        index : `tuple` [`int`, `int`]
            ``(x, y)`` position of the patch in the patch grid.

        Returns
        -------
        patchInfo : `PatchInfo`
            Patch whose inner and outer bounding boxes are identical.
        """
        x, y = index
        width = self.patchWidth
        height = self.patchHeight

        # NOTE(review): patch origins are measured from pixel (0, 0); the
        # minimum corner of the tract bbox is not added. Confirm that this
        # is the intended layout before relying on patch/tract overlap.
        x = x*self.patchWidth
        y = y*self.patchHeight

        bbox = Box2I(Point2I(x, y), Extent2I(width, height))

        patchInfo = PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
        )
        return patchInfo

    def __getitem__(self, index):
        """Allow ``tract[x, y]`` as shorthand for ``getPatchInfo((x, y))``."""
        return self.getPatchInfo(index)

    def __iter__(self):
        """Yield every patch, iterating over x fastest and y slowest."""
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))

116 

117 

class MockSkyMap:
    """A SkyMap built from a list of tract bounding boxes.

    Tests are simpler to write when each Tract is specified directly by
    its bounding box in pixel coordinates together with the WCS of the
    exposure. This class holds those definitions and creates the
    corresponding tracts on request.

    Parameters
    ----------
    bboxes
        Sequence of bounding boxes, one per tract.
    wcs
        WCS handed to every tract that is created.
    numPatches
        Patch count handed to every tract that is created.
    """

    def __init__(self, bboxes, wcs, numPatches):
        self.numPatches = numPatches
        self.wcs = wcs
        self.bboxes = bboxes

    def __iter__(self):
        """Yield one freshly generated tract per bounding box, in order."""
        for tractIndex, _ in enumerate(self.bboxes):
            yield self.generateTract(tractIndex)

    def __getitem__(self, index):
        """Allow ``skyMap[index]`` as shorthand for ``generateTract(index)``."""
        return self.generateTract(index)

    def generateTract(self, index):
        """Create the `MockTractInfo` for the bounding box at ``index``.

        The index also serves as the name of the tract.
        """
        tractBBox = self.bboxes[index]
        return MockTractInfo(index, tractBBox, self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``.

        A `NullTract` is returned when no tract contains the coordinate.
        """
        found = next((tract for tract in self if tract.contains(coord)), None)
        if found is None:
            return NullTract()
        return found

148 

149 

class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Tests of the primary-flag columns computed for a sample exposure.

    One test runs `CalibrateTask` on the characterized exposure
    (single-frame case); the other runs detection, scarlet deblending,
    measurement and `SetPrimaryFlagsTask` against a `MockSkyMap`.
    """

    def setUp(self):
        """Load the sample exposure from disk and characterize it."""
        # Load sample input from disk
        expPath = os.path.join(getPackageDir("pipe_tasks"), "tests", "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

        # set log level so that warnings do not display
        logging.getLogger("calibrate").setLevel(logging.ERROR)

    def tearDown(self):
        """Drop the references created in `setUp`."""
        del self.exposure
        # A bare ``self.charImResults`` expression is a no-op; the results
        # must be deleted explicitly to release them between tests.
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum((outputCat["detect_isPrimary"]) & (outputCat["sky_source"])), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum((outputCat["detect_isPrimary"]) & (outputCat["deblend_nChild"] > 0)), 0)

        # The scarlet-only model column must be absent from the single-frame
        # catalog. Look up the same column name that the scarlet test reads;
        # a misspelled name would raise KeyError regardless of the schema
        # and make this check meaningless.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelSource")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, 3)
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the sky object task and the flag marking its sources
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        primaryConfig = SetPrimaryFlagsTask.ConfigClass()
        setPrimaryTask = SetPrimaryFlagsTask(config=primaryConfig, schema=schema,
                                             name="setPrimaryFlags", isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add the first five sky object footprints as sky sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deblend
        result = deblendTask.run(coadds, catalog)
        # measure
        measureTask.run(result["test"], self.exposure)
        outputCat = result["test"]
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of detect_isDeblendedSource and
        # detect_isDeblendedModelSource sources (once pseudo sources are
        # excluded from the former), since they both have the same blended
        # sources and only differ over which model to use for the isolated
        # sources.
        isPseudo = getPseudoSources(outputCat, primaryConfig.pseudoFilterList, schema, setPrimaryTask.log)
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum((outputCat["detect_isPrimary"]) & (outputCat["merge_peak_sky"])), 0)

        # Check that sky objects have not been deblended
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["deblend_nChild"] == 0)
        )

279 

280 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`."""

283 

284 

def setup_module(module):
    """Initialize the LSST test utilities for this module.

    Parameters
    ----------
    module
        The test module being set up; not used here.
    """
    lsst.utils.tests.init()

287 

288 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities, then hand
    # control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()

291 unittest.main()