Coverage for tests / test_utils.py: 23%
135 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:22 +0000
1# This file is part of source_injection.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
23import os
24import unittest
25from contextlib import redirect_stdout
26from io import StringIO
28import numpy as np
30import lsst.utils.tests
31from lsst.daf.butler.tests import makeTestCollection, makeTestRepo
32from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
33from lsst.obs.base.instrument_tests import DummyCam
34from lsst.pipe.base import Pipeline
35from lsst.skymap.ringsSkyMap import RingsSkyMap, RingsSkyMapConfig
36from lsst.source.injection import (
37 ConsolidateInjectedCatalogsConfig,
38 ConsolidateInjectedCatalogsTask,
39 ExposureInjectTask,
40 ingest_injection_catalog,
41 make_injection_pipeline,
42 show_source_types,
43)
44from lsst.source.injection.utils.test_utils import (
45 make_test_exposure,
46 make_test_injection_catalog,
47 make_test_reference_pipeline,
48)
49from lsst.utils.tests import TestCase
# Absolute path of the directory containing this test module; the temporary
# butler repository used by the tests is created beneath it.
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
class SourceInjectionUtilsTestCase(TestCase):
    """Test the utility functions in the source_injection package."""

    @classmethod
    def setUpClass(cls):
        """Create a temporary butler repo, register DummyCam and build a skymap.

        Logging is disabled at the CRITICAL level for the lifetime of the
        class; `tearDownClass` re-enables it.
        """
        cls.root = makeTestTempDir(TEST_DIR)
        cls.creator_butler = makeTestRepo(cls.root)
        cls.writeable_butler = makeTestCollection(cls.creator_butler)
        # Register an instrument so we can get some bands.
        DummyCam().register(cls.writeable_butler.registry)
        skyMapConfig = RingsSkyMapConfig()
        skyMapConfig.numRings = 3
        cls.skyMap = RingsSkyMap(config=skyMapConfig)
        logging.disable(logging.CRITICAL)  # Suppress logging output

    @classmethod
    def tearDownClass(cls):
        """Release class fixtures, remove the temp repo and re-enable logging."""
        try:
            del cls.writeable_butler
            del cls.creator_butler
            del cls.skyMap
            removeTestTempDir(cls.root)
        finally:
            # Always undo the logging.disable() from setUpClass, even if the
            # cleanup above fails, so later tests in the session still log.
            logging.disable(logging.NOTSET)  # Re-enable logging output

    def setUp(self):
        """Build the per-test exposure, catalogs, pipeline and config."""
        self.exposure = make_test_exposure()
        self.injection_catalog = make_test_injection_catalog(
            self.exposure.getWcs(),
            self.exposure.getBBox(),
        )
        # Decrement every other ID in the middle half of the catalog so those
        # rows share a group_id with the row before them (two-row groups).
        n_rows = len(self.injection_catalog)
        group_ids = np.arange(n_rows)
        group_ids[n_rows // 4 : (n_rows * 3) // 4 : 2] -= 1
        self.injection_catalog["group_id"] = group_ids
        self.reference_pipeline = make_test_reference_pipeline()
        self.consolidate_injected_config = ConsolidateInjectedCatalogsConfig(
            get_catalogs_from_butler=False,
        )
        # Stand-in for an injected catalog: the input catalog plus draw-size
        # and flag columns (both zero), with the first five rows flagged.
        self.injected_catalog = self.injection_catalog.copy()
        self.injected_catalog.add_columns(cols=[0, 0], names=["injection_draw_size", "injection_flag"])
        self.injected_catalog["injection_flag"][:5] = 1

    def tearDown(self):
        """Release every per-test fixture created in `setUp`."""
        del self.exposure
        del self.injection_catalog
        del self.reference_pipeline
        del self.consolidate_injected_config
        del self.injected_catalog

    def test_generate_injection_catalog(self):
        """Check the generated injection catalog's length and columns."""
        self.assertEqual(len(self.injection_catalog), 30)
        expected_columns = {"injection_id", "ra", "dec", "source_type", "mag", "group_id"}
        self.assertEqual(set(self.injection_catalog.columns), expected_columns)

    def test_make_injection_pipeline(self):
        """Check task exclusion, subsets and connection renaming when merging."""
        injection_pipeline = Pipeline("injection_pipeline")
        injection_pipeline.addTask(ExposureInjectTask, "inject_exposure")

        additional_pipeline = Pipeline("additional_pipeline")
        additional_pipeline.addTask(ConsolidateInjectedCatalogsTask, "additional_task")

        # Explicitly set connection names to non-default values.
        injection_pipeline.addConfigOverride("inject_exposure", "connections.input_exposure", "A")
        injection_pipeline.addConfigOverride("inject_exposure", "connections.output_exposure", "B")
        injection_pipeline.addConfigOverride("inject_exposure", "connections.output_catalog", "C")

        # Merge the injection pipeline into the main reference pipeline.
        merged_pipeline = make_injection_pipeline(
            dataset_type_name="postISRCCD",  # Unchanged to match task default
            reference_pipeline=self.reference_pipeline,
            injection_pipeline=injection_pipeline,
            exclude_subsets=False,
            excluded_tasks={"calibrate"},
            prefix="injected_",
            instrument="lsst.obs.subaru.HyperSuprimeCam",
            additional_pipelines=[additional_pipeline],
            additional_subset=["newSubset:newSubset description"],
            log_level=logging.DEBUG,
        )

        # Test that only the expected tasks are present in the merged pipeline.
        expected_task_labels = set(self.reference_pipeline.task_labels) - {"calibrate"}
        surviving_task_labels = set(self.reference_pipeline.task_labels) & set(merged_pipeline.task_labels)
        self.assertEqual(expected_task_labels, surviving_task_labels)

        # Test that all surviving tasks are still in a subset; report the
        # offending label on failure rather than just a count mismatch.
        for label in surviving_task_labels:
            self.assertTrue(
                merged_pipeline.findSubsetsWithLabel(label),
                msg=f"Task {label!r} is not in any subset of the merged pipeline.",
            )
        self.assertIn("newSubset", merged_pipeline.findSubsetsWithLabel("additional_task"))

        # Test that connection names have been properly configured.
        for t in merged_pipeline.to_graph().tasks.values():
            if t.label == "isr":
                self.assertEqual(t.outputs["outputExposure"].dataset_type_name, "postISRCCD")
            elif t.label == "inject_exposure":
                self.assertEqual(t.inputs["input_exposure"].dataset_type_name, "postISRCCD")
                self.assertEqual(t.outputs["output_exposure"].dataset_type_name, "injected_postISRCCD")
                self.assertEqual(t.outputs["output_catalog"].dataset_type_name, "injected_postISRCCD_catalog")
            elif t.label == "characterizeImage":
                self.assertEqual(t.inputs["exposure"].dataset_type_name, "injected_postISRCCD")
                self.assertEqual(t.outputs["characterized"].dataset_type_name, "injected_icExp")
                self.assertEqual(t.outputs["backgroundModel"].dataset_type_name, "injected_icExpBackground")
                self.assertEqual(t.outputs["sourceCat"].dataset_type_name, "injected_icSrc")

    def test_ingest_injection_catalog(self):
        """Check that an ingested injection catalog round-trips through the butler."""
        input_dataset_refs = ingest_injection_catalog(
            writeable_butler=self.writeable_butler,
            table=self.injection_catalog,
            band="g",
            output_collection="test_collection",
            dataset_type_name="injection_catalog",
            log_level=logging.DEBUG,
        )
        output_dataset_refs = self.writeable_butler.registry.queryDatasets(
            "injection_catalog",
            collections="test_collection",
        )
        self.assertEqual(len(input_dataset_refs), output_dataset_refs.count())
        input_ids = {x.id for x in input_dataset_refs}
        output_ids = {x.id for x in output_dataset_refs}
        self.assertEqual(input_ids, output_ids)
        injected_catalog = self.writeable_butler.get(input_dataset_refs[0])
        self.assertTrue(all(self.injection_catalog == injected_catalog))

    def test_consolidate_injected_catalogs(self):
        """Check per-band consolidation via the config's consolidate_catalogs."""
        catalog_dict = {"g": self.injected_catalog, "r": self.injected_catalog}
        output_catalog = self.consolidate_injected_config.consolidate_catalogs(
            catalog_dict=catalog_dict,
            skymap=self.skyMap,
            tract=9,
            copy_catalogs=True,
        )
        self.assertEqual(len(output_catalog), 30)
        expected_columns = [
            "injected_id",
            "ra",
            "dec",
            "source_type",
            "g_mag",
            "r_mag",
            "patch",
            "injection_id",
            "injection_draw_size",
            "injection_flag",
            "injected_isPatchInner",
            "injected_isTractInner",
            "injected_isPrimary",
            "group_id",
            "g_injection_flag",
            "r_injection_flag",
        ]
        self.assertListEqual(output_catalog.colnames, expected_columns)
        self.assertEqual(sum(output_catalog["injection_flag"]), 5)
        self.assertEqual(sum(output_catalog["injected_isPatchInner"]), 30)
        self.assertEqual(sum(output_catalog["injected_isTractInner"]), 30)
        self.assertEqual(sum(output_catalog["injected_isPrimary"]), 25)

    def test_consolidate_injected_catalog_task(self):
        """Check that ConsolidateInjectedCatalogsTask merges rows by group ID."""
        group_id_key = "group_id"
        config = ConsolidateInjectedCatalogsConfig(
            groupIdKey=group_id_key,
            pixel_match_radius=-1,
            columns_extra=[],
            get_catalogs_from_butler=False,
        )
        task = ConsolidateInjectedCatalogsTask(config=config)
        catalog_dict = {"g": self.injected_catalog, "r": self.injected_catalog}
        output_catalog = task.run(
            catalog_dict=catalog_dict,
            skymap=self.skyMap,
            tract=9,
        ).output_catalog
        groupIds, counts = np.unique(
            self.injection_catalog[group_id_key],
            return_counts=True,
        )
        # Largest number of rows sharing one group ID = number of components.
        n_comps = np.max(counts)
        self.assertEqual(len(output_catalog), len(groupIds))
        expected_columns = [
            config.groupIdKey,
            config.injectionKey,
            config.col_ra,
            config.col_dec,
        ]
        for band in catalog_dict:
            columns_band = [
                f"{band}_{config.injectionKey}",
                f"{band}_{config.col_mag}",
            ]
            for compnum in range(1, n_comps + 1):
                columns_band.extend(
                    [
                        f"{band}_comp{compnum}_source_type",
                        f"{band}_comp{compnum}_{config.injectionKey}",
                    ]
                )
            expected_columns.extend(columns_band)
        expected_columns.extend(
            [
                "patch",
                "injected_isPatchInner",
                "injected_isTractInner",
                "injected_isPrimary",
                "injected_id",
            ]
        )
        self.assertEqual(set(output_catalog.colnames), set(expected_columns))
        self.assertEqual(sum(output_catalog["injection_flag"]), 5)
        self.assertEqual(sum(output_catalog["injected_isPatchInner"]), 22)
        self.assertEqual(sum(output_catalog["injected_isTractInner"]), 22)
        self.assertEqual(sum(output_catalog["injected_isPrimary"]), 17)

    def test_show_source_types(self):
        """Check that show_source_types prints the Sersic profile signature."""
        buffer = StringIO()
        with redirect_stdout(buffer):
            show_source_types(wrap_width=80)
        output = buffer.getvalue()
        # NOTE(review): the leading whitespace inside these expected lines was
        # copied as it appears in this copy of the file; confirm it matches the
        # exact indentation that show_source_types emits.
        self.assertIn(
            "Sersic:\n"
            " (n, half_light_radius=None, scale_radius=None, mag=None, trunc=0.0,\n"
            " flux_untruncated=False)",
            output,
        )
class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the lsst.utils memory checks for the tests in this script.

    No tests are added here; all behaviour is inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """
def setup_module(module):
    """Initialise `lsst.utils.tests` before pytest runs this module.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up, passed in by pytest; not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution (not via pytest): initialise lsst.utils.tests, then
    # hand control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()