Coverage for tests / test_isPrimaryFlag.py: 25%

165 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:53 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import unittest 

24import numpy as np 

25 

26from lsst.geom import Point2I, Box2I, Extent2I 

27from lsst.skymap import TractInfo 

28from lsst.skymap.patchInfo import PatchInfo 

29import lsst.afw.image as afwImage 

30import lsst.utils.tests 

31from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig 

32from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig 

33from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask, SetPrimaryFlagsTask 

34import lsst.meas.extensions.scarlet as mes 

35from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask 

36from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask 

37from lsst.meas.base import SingleFrameMeasurementTask 

38from lsst.afw.table import SourceCatalog 

39 

# Absolute directory of this test module; sample data under "data/" is
# located relative to it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))

41 

42 

class NullTract(TractInfo):
    """Placeholder tract for positions outside every tract of a `MockSkyMap`.

    `BaseSkyMap.findTract(coord)` always hands back a tract, even when
    ``coord`` is not actually located in it. To mimic that, `MockSkyMap`
    returns a `NullTract` for regions not covered by any of its tracts.
    """

    def __init__(self):
        """Create the placeholder without running ``TractInfo.__init__``.

        The base initializer is deliberately skipped: this object carries
        no geometry and only needs to answer `getId`.
        """

    def getId(self):
        """Return `None`.

        `MockSkyMap` tracts use integer indices as ids, so this never
        compares equal to one of them.
        """
        return None

58 

class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure.

    Only the methods from `TractInfo` needed to make the tests
    pass are implemented here. Since this is just for testing,
    it isn't sophisticated and requires the size of the bounding
    box to be evenly divisible by the number of patches along
    each axis.

    Parameters
    ----------
    name : `object`
        Identifier returned by `getId`.
    bbox : `lsst.geom.Box2I`
        Bounding box of the tract, in the pixel frame of ``wcs``.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS used to map sky coordinates to pixels (the exposure's WCS
        in the tests below).
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y.

    Raises
    ------
    ValueError
        Raised if ``bbox`` cannot be split evenly into ``numPatches``.
    """
    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        nx, ny = numPatches
        width, height = bbox.getWidth(), bbox.getHeight()
        # Raise rather than ``assert`` so the check still runs under ``-O``;
        # otherwise the integer division below would silently truncate.
        if width % nx != 0 or height % ny != 0:
            raise ValueError(
                f"Tract bbox size {width}x{height} is not evenly divisible "
                f"by numPatches={tuple(numPatches)}"
            )
        self.patchWidth = width//nx
        self.patchHeight = height//ny

    def contains(self, coord):
        """Return whether ``coord`` maps, via ``self.wcs``, into ``self.bbox``.

        The pixel position is converted to an integer point with `Point2I`
        before the containment test.
        """
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the identifier this tract was constructed with."""
        return self.name

    def getNumPatches(self):
        """Return the ``(nx, ny)`` number of patches in this tract."""
        return self._numPatches

    def _sequentialIndex(self, index):
        """Return the row-major sequential index ``nx*y + x`` of patch ``index``.

        ``index`` is the ``(x, y)`` patch index, not a pixel position.
        """
        x, y = index
        nx, _ = self._numPatches
        return nx*y + x

    def getPatchInfo(self, index):
        """Return the `PatchInfo` for the patch at ``index = (x, y)``.

        The inner and outer bounding boxes are identical (no overlap
        between patches).
        """
        x, y = index
        width = self.patchWidth
        height = self.patchHeight

        # NOTE(review): the patch origin is measured from pixel (0, 0), not
        # from ``self.bbox``'s minimum corner; presumably intended for these
        # tests, confirm before reusing elsewhere.
        bbox = Box2I(Point2I(x*width, y*height), Extent2I(width, height))

        # The sequential index must come from the patch index; it must not
        # come from the pixel offsets computed above.
        return PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            sequentialIndex=self._sequentialIndex(index),
            tractWcs=self.wcs,
        )

    def __getitem__(self, index):
        """Return the patch at ``index``; equivalent to `getPatchInfo`."""
        return self.getPatchInfo(index)

    def __iter__(self):
        """Yield every patch, x varying fastest within each row of y."""
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))

122 

123 

class MockSkyMap:
    """A SkyMap whose tracts are given as explicit bounding boxes.

    Testing is made easier when we can specifically define
    a Tract in terms of its bounding box in pixel coordinates
    along with a WCS for the exposure. Each entry of ``bboxes``
    becomes one `MockTractInfo`; all tracts share ``wcs`` and
    ``numPatches``.

    Parameters
    ----------
    bboxes : `list` [`lsst.geom.Box2I`]
        One bounding box per tract; the list position is the tract id.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS handed to every generated tract.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y in every tract.
    """
    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def __iter__(self):
        """Yield a newly generated tract for each bounding box, in order."""
        for tractIndex, _ in enumerate(self.bboxes):
            yield self.generateTract(tractIndex)

    def __getitem__(self, index):
        """Return the tract at position ``index``."""
        return self.generateTract(index)

    def generateTract(self, index):
        """Build the `MockTractInfo` for ``self.bboxes[index]``.

        A new object is created on every call; tracts are not cached.
        """
        tractBBox = self.bboxes[index]
        return MockTractInfo(index, tractBBox, self.wcs, self.numPatches)

    def findTract(self, coord):
        """Return the first tract containing ``coord``, else a `NullTract`."""
        hit = next((tract for tract in self if tract.contains(coord)), None)
        return NullTract() if hit is None else hit

154 

155 

class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Tests of the primary-source flags written by `SetPrimaryFlagsTask`,
    for single-frame calibration and for a scarlet-deblended coadd.
    """

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(TESTDIR, "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImConfig.measureApCorr.sourceSelector["science"].doSignalToNoise = False
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

    def tearDown(self):
        # Drop the per-test fixtures created in setUp so they can be freed
        # between tests.
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibConfig.doComputeSummaryStats = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        isPrimary = outputCat["detect_isPrimary"]
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(isPrimary & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(np.sum(isPrimary & (outputCat["deblend_nChild"] > 0)), 0)

        # NOTE(review): no code visible here adds a field with this name, so
        # this check may be vacuous; the scarlet test below reads
        # "detect_isDeblendedModelSource" instead. Confirm which column
        # single-frame processing is meant to lack.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelPrimary")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the fake source injection task
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deconvolution task
        deconvolveConfig = DeconvolveExposureTask.ConfigClass()
        deconvolveTask = DeconvolveExposureTask(config=deconvolveConfig)

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        scarletConfig.processSingles = True
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        # As of DM-51670 we also include `base_PsfFlux` to ensure that
        # the measurement plugins run correctly with the split between
        # parent and child catalogs.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord", "base_PsfFlux"]
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        setPrimaryTask = SetPrimaryFlagsTask(schema=schema, isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources
        # add fake sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)
        # deconvolve the images
        deconvolved = deconvolveTask.run(coadds["test"], catalog).deconvolved
        mDeconvolved = afwImage.MultibandExposure.fromExposures(["test"], [deconvolved])
        # deblend
        # This is a hack because the variance is not calibrated properly
        # (it is 3 orders of magnitude too high), which causes the deblender
        # to improperly deblend most sources due to the sparsity constraint.
        coadds.variance.array[:] = 2e-1
        mDeconvolved.variance.array[:] = 2e-1
        result = deblendTask.run(coadds, mDeconvolved, catalog)
        modelData = result.scarletModelData
        catalog = result.deblendedCatalog
        # Attach footprints to the catalog
        mes.io.updateCatalogFootprints(
            modelData=modelData,
            catalog=catalog,
            band="test",
            imageForRedistribution=coadds["test"],
            removeScarletData=True,
            updateFluxColumns=True,
        )

        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        isPseudo = outputCat["merge_peak_sky"]

        # Check that all 5 pseudo-sources were created
        self.assertEqual(np.sum(isPseudo), 5)

        # Excluding pseudo (sky) sources, there should be the same number of
        # detect_isDeblendedSource and detect_isDeblendedModelSource rows,
        # since both share the same blended sources and only differ over
        # which model to use for the isolated sources.
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)

        # Check that sky objects have not been deblended
        # (deblended sources have parent > 0)
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["parent"] == 0)
        )

        # Check that measurements were performed on all of the children;
        # separate assertions report which condition failed.
        psfFlux = outputCat["base_PsfFlux_instFlux"]
        self.assertTrue(np.all(psfFlux != 0))
        self.assertTrue(np.all(np.isfinite(psfFlux)))

318 

319 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull in the module-level checks from `lsst.utils.tests.MemoryTestCase`."""

322 

323 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

326 

327 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities first, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()