Coverage for tests / test_utils.py: 23%

135 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:11 +0000

1# This file is part of source_injection. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23import os 

24import unittest 

25from contextlib import redirect_stdout 

26from io import StringIO 

27 

28import numpy as np 

29 

30import lsst.utils.tests 

31from lsst.daf.butler.tests import makeTestCollection, makeTestRepo 

32from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

33from lsst.obs.base.instrument_tests import DummyCam 

34from lsst.pipe.base import Pipeline 

35from lsst.skymap.ringsSkyMap import RingsSkyMap, RingsSkyMapConfig 

36from lsst.source.injection import ( 

37 ConsolidateInjectedCatalogsConfig, 

38 ConsolidateInjectedCatalogsTask, 

39 ExposureInjectTask, 

40 ingest_injection_catalog, 

41 make_injection_pipeline, 

42 show_source_types, 

43) 

44from lsst.source.injection.utils.test_utils import ( 

45 make_test_exposure, 

46 make_test_injection_catalog, 

47 make_test_reference_pipeline, 

48) 

49from lsst.utils.tests import TestCase 

50 

# Absolute path of the directory containing this test file; temporary test
# repositories are created underneath it.
TEST_DIR = os.path.dirname(os.path.abspath(__file__))

52 

53 

class SourceInjectionUtilsTestCase(TestCase):
    """Test the utility functions in the source_injection package."""

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TEST_DIR)
        # Register removal immediately, so the temporary directory is cleaned
        # up even if a later step of this method raises (in which case
        # tearDownClass is never called). Class cleanups run after
        # tearDownClass, so the butlers are released before removal.
        cls.addClassCleanup(removeTestTempDir, cls.root)
        cls.creator_butler = makeTestRepo(cls.root)
        cls.writeable_butler = makeTestCollection(cls.creator_butler)
        # Register an instrument so we can get some bands.
        DummyCam().register(cls.writeable_butler.registry)
        skyMapConfig = RingsSkyMapConfig()
        skyMapConfig.numRings = 3
        cls.skyMap = RingsSkyMap(config=skyMapConfig)
        logging.disable(logging.CRITICAL)  # Suppress logging output

    @classmethod
    def tearDownClass(cls):
        del cls.writeable_butler
        del cls.creator_butler
        del cls.skyMap
        # The temporary repository directory is removed by the class cleanup
        # registered in setUpClass.
        logging.disable(logging.NOTSET)  # Re-enable logging output

    def setUp(self):
        self.exposure = make_test_exposure()
        self.injection_catalog = make_test_injection_catalog(
            self.exposure.getWcs(),
            self.exposure.getBBox(),
        )
        # Give every row its own group ID, then decrement the ID of every
        # other row in the middle half of the catalog. Each such row shares
        # its predecessor's group, producing some two-component groups.
        n_rows = len(self.injection_catalog)
        group_ids = np.arange(n_rows)
        group_ids[n_rows // 4 : (n_rows * 3) // 4 : 2] -= 1
        self.injection_catalog["group_id"] = group_ids
        self.reference_pipeline = make_test_reference_pipeline()
        self.consolidate_injected_config = ConsolidateInjectedCatalogsConfig(
            get_catalogs_from_butler=False,
        )
        # Mock up an output catalog: the injection catalog plus the two extra
        # columns added below, with injection_flag set on the first five rows.
        self.injected_catalog = self.injection_catalog.copy()
        self.injected_catalog.add_columns(cols=[0, 0], names=["injection_draw_size", "injection_flag"])
        self.injected_catalog["injection_flag"][:5] = 1

    def tearDown(self):
        del self.exposure
        del self.injection_catalog
        del self.reference_pipeline
        del self.consolidate_injected_config
        del self.injected_catalog

    def test_generate_injection_catalog(self):
        """Check the size and columns of the generated injection catalog."""
        self.assertEqual(len(self.injection_catalog), 30)
        expected_columns = {"injection_id", "ra", "dec", "source_type", "mag", "group_id"}
        self.assertEqual(set(self.injection_catalog.columns), expected_columns)

    def test_make_injection_pipeline(self):
        """Check task selection, subsets and connection renaming when merging."""
        injection_pipeline = Pipeline("injection_pipeline")
        injection_pipeline.addTask(ExposureInjectTask, "inject_exposure")

        additional_pipeline = Pipeline("additional_pipeline")
        additional_pipeline.addTask(ConsolidateInjectedCatalogsTask, "additional_task")

        # Explicitly set connection names to non-default values.
        injection_pipeline.addConfigOverride("inject_exposure", "connections.input_exposure", "A")
        injection_pipeline.addConfigOverride("inject_exposure", "connections.output_exposure", "B")
        injection_pipeline.addConfigOverride("inject_exposure", "connections.output_catalog", "C")

        # Merge the injection pipeline into the main reference pipeline.
        merged_pipeline = make_injection_pipeline(
            dataset_type_name="postISRCCD",
            reference_pipeline=self.reference_pipeline,
            injection_pipeline=injection_pipeline,
            exclude_subsets=False,
            excluded_tasks={"calibrate"},
            prefix="injected_",
            instrument="lsst.obs.subaru.HyperSuprimeCam",
            additional_pipelines=[additional_pipeline],
            additional_subset=["newSubset:newSubset description"],
            log_level=logging.DEBUG,
        )

        # Test that only the expected tasks are present in the merged pipeline.
        reference_labels = set(self.reference_pipeline.task_labels)
        expected_task_labels = reference_labels - {"calibrate"}
        surviving_task_labels = reference_labels & set(merged_pipeline.task_labels)
        self.assertEqual(expected_task_labels, surviving_task_labels)

        # Test that all surviving tasks are still in a subset.
        surviving_task_subsets = [merged_pipeline.findSubsetsWithLabel(x) for x in surviving_task_labels]
        self.assertEqual(sum(1 for s in surviving_task_subsets if s), len(surviving_task_labels))
        self.assertIn("newSubset", merged_pipeline.findSubsetsWithLabel("additional_task"))

        # Test that connection names have been properly configured. Index the
        # graph's tasks by label and require each checked task to be present,
        # so a missing task fails the test rather than silently skipping its
        # assertions.
        tasks = {t.label: t for t in merged_pipeline.to_graph().tasks.values()}
        for label in ("isr", "inject_exposure", "characterizeImage"):
            self.assertIn(label, tasks)

        isr = tasks["isr"]
        self.assertEqual(isr.outputs["outputExposure"].dataset_type_name, "postISRCCD")

        inject_exposure = tasks["inject_exposure"]
        self.assertEqual(inject_exposure.inputs["input_exposure"].dataset_type_name, "postISRCCD")
        self.assertEqual(inject_exposure.outputs["output_exposure"].dataset_type_name, "injected_postISRCCD")
        self.assertEqual(
            inject_exposure.outputs["output_catalog"].dataset_type_name, "injected_postISRCCD_catalog"
        )

        characterize = tasks["characterizeImage"]
        self.assertEqual(characterize.inputs["exposure"].dataset_type_name, "injected_postISRCCD")
        self.assertEqual(characterize.outputs["characterized"].dataset_type_name, "injected_icExp")
        self.assertEqual(characterize.outputs["backgroundModel"].dataset_type_name, "injected_icExpBackground")
        self.assertEqual(characterize.outputs["sourceCat"].dataset_type_name, "injected_icSrc")

    def test_ingest_injection_catalog(self):
        """Check that an ingested catalog can be found and read back intact."""
        input_dataset_refs = ingest_injection_catalog(
            writeable_butler=self.writeable_butler,
            table=self.injection_catalog,
            band="g",
            output_collection="test_collection",
            dataset_type_name="injection_catalog",
            log_level=logging.DEBUG,
        )
        output_dataset_refs = self.writeable_butler.registry.queryDatasets(
            "injection_catalog",
            collections="test_collection",
        )
        self.assertEqual(len(input_dataset_refs), output_dataset_refs.count())
        input_ids = {x.id for x in input_dataset_refs}
        output_ids = {x.id for x in output_dataset_refs}
        self.assertEqual(input_ids, output_ids)
        ingested_catalog = self.writeable_butler.get(input_dataset_refs[0])
        self.assertTrue(all(self.injection_catalog == ingested_catalog))

    def test_consolidate_injected_catalogs(self):
        """Check consolidation of per-band injected catalogs for one tract."""
        catalog_dict = {"g": self.injected_catalog, "r": self.injected_catalog}
        output_catalog = self.consolidate_injected_config.consolidate_deepCoadd(
            catalog_dict=catalog_dict,
            skymap=self.skyMap,
            tract=9,
            copy_catalogs=True,
        )
        self.assertEqual(len(output_catalog), 30)
        expected_columns = [
            "injected_id",
            "ra",
            "dec",
            "source_type",
            "g_mag",
            "r_mag",
            "patch",
            "injection_id",
            "injection_draw_size",
            "injection_flag",
            "injected_isPatchInner",
            "injected_isTractInner",
            "injected_isPrimary",
            "group_id",
            "g_injection_flag",
            "r_injection_flag",
        ]
        self.assertListEqual(output_catalog.colnames, expected_columns)
        self.assertEqual(sum(output_catalog["injection_flag"]), 5)
        self.assertEqual(sum(output_catalog["injected_isPatchInner"]), 30)
        self.assertEqual(sum(output_catalog["injected_isTractInner"]), 30)
        self.assertEqual(sum(output_catalog["injected_isPrimary"]), 25)

    def test_consolidate_injected_catalog_task(self):
        """Check that the consolidation task merges rows sharing a group ID."""
        group_id_key = "group_id"
        config = ConsolidateInjectedCatalogsConfig(
            groupIdKey=group_id_key,
            pixel_match_radius=-1,
            columns_extra=[],
            get_catalogs_from_butler=False,
        )
        task = ConsolidateInjectedCatalogsTask(config=config)
        catalog_dict = {"g": self.injected_catalog, "r": self.injected_catalog}
        output_catalog = task.run(
            catalog_dict=catalog_dict,
            skymap=self.skyMap,
            tract=9,
        ).output_catalog
        # One output row is expected per distinct group ID; the largest
        # group size sets how many per-component columns each band gets.
        group_ids, counts = np.unique(
            self.injection_catalog[group_id_key],
            return_counts=True,
        )
        n_comps = int(np.max(counts))
        self.assertEqual(len(output_catalog), len(group_ids))
        expected_columns = [
            config.groupIdKey,
            config.injectionKey,
            config.col_ra,
            config.col_dec,
        ]
        for band in catalog_dict:
            columns_band = [
                f"{band}_{config.injectionKey}",
                f"{band}_{config.col_mag}",
            ]
            for compnum in range(1, n_comps + 1):
                columns_band.extend(
                    [
                        f"{band}_comp{compnum}_source_type",
                        f"{band}_comp{compnum}_{config.injectionKey}",
                    ]
                )
            expected_columns.extend(columns_band)
        expected_columns.extend(
            [
                "patch",
                "injected_isPatchInner",
                "injected_isTractInner",
                "injected_isPrimary",
                "injected_id",
            ]
        )
        self.assertEqual(set(output_catalog.colnames), set(expected_columns))
        self.assertEqual(sum(output_catalog["injection_flag"]), 5)
        self.assertEqual(sum(output_catalog["injected_isPatchInner"]), 22)
        self.assertEqual(sum(output_catalog["injected_isTractInner"]), 22)
        self.assertEqual(sum(output_catalog["injected_isPrimary"]), 17)

    def test_show_source_types(self):
        """Check that the Sersic signature appears in the printed listing."""
        buffer = StringIO()
        with redirect_stdout(buffer):
            show_source_types(wrap_width=80)
        output = buffer.getvalue()
        self.assertIn(
            "Sersic:\n"
            "  (n, half_light_radius=None, scale_radius=None, mag=None, trunc=0.0,\n"
            "  flux_untruncated=False)",
            output,
        )

275 

276 

class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Memory checks for this module, inherited unchanged from
    ``lsst.utils.tests.MemoryTestCase``.
    """

281 

282 

def setup_module(module):
    """Initialize the LSST test utilities when pytest sets up this module.

    Parameters
    ----------
    module : `module`
        The test module being set up; not used.
    """
    lsst.utils.tests.init()

286 

287 

if __name__ == "__main__":
    # Support running this file directly with the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()