Coverage for tests / test_matchSourceInjected.py: 22%
83 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:21 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:21 +0000
#
# This file is part of ap_pipe.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
24import numpy as np
25import pandas as pd
26import unittest
28from astropy.table import Table
29from lsst.afw.table import SourceCatalog, SourceTable
30from lsst.afw.image import ExposureF
31from lsst.afw.geom import makeSkyWcs
32from lsst.geom import SpherePoint, degrees, Point2D, Extent2I
33import lsst.utils.tests
35from lsst.pipe.tasks.matchDiffimSourceInjected import (
36 MatchInjectedToDiaSourceTask,
37 MatchInjectedToAssocDiaSourceTask,
38 MatchInjectedToDiaSourceConfig,
39 MatchInjectedToAssocDiaSourceConfig,
40)
class BaseTestMatchInjected(lsst.utils.tests.TestCase):
    """Common fixtures for the injected-source matching tests.

    ``setUp`` builds the following from a fixed random seed, so the match
    counts expected by the subclasses are reproducible:

    ``injectedCat``
        10 injected ``DeltaFunction`` sources scattered over a 1/6 degree
        (10 arcmin) box whose corner is at (RA, Dec) = (0, -30) degrees.
    ``diaSources``
        5 DIA sources at random positions in the same box (ids 100-104),
        followed by 8 DIA sources placed a small random offset away from the
        first 8 injected sources (ids 200-207).
    ``diffIm``
        A 4096x4096 ``ExposureF`` whose WCS puts (0, -30) degrees at
        pixel (0, 0).
    ``injectedCatForTrimming``
        A copy of ``injectedCat`` with one extra source that lies outside
        ``diffIm``.
    ``assocDiaSources``
        7 associated DIA sources. Only 4 of them (ids 201, 202, 205, 207)
        are counterparts of injected sources.
    """

    def setUp(self):
        # NOTE: the order of the rng draws below is significant. Reordering
        # them changes every position and can break the expected counts.
        rng = np.random.RandomState(6)
        nInjected = 10
        boxSizeDeg = 1/6.  # 10 arcmin
        # Maximum coordinate offset (degrees) applied independently in RA and
        # Dec: 0.35 arcsec (~0.5/sqrt(2)). The total separation of an offset
        # DIA source from its injected source is therefore below
        # 0.35 * sqrt(2) ~= 0.495 arcsec, which is inside the 0.5 arcsec match
        # radius the tests configure.
        offsetFactor = 1./3600 * 0.35

        # Injected catalog: draw RA, then Dec, then magnitude.
        injectedRa = boxSizeDeg * rng.random(size=nInjected)
        injectedDec = -30 + boxSizeDeg * rng.random(size=nInjected)
        injectedMag = 20.25 - 0.5 * rng.random(size=nInjected)
        self.injectedCat = Table(
            {
                "injection_id": [1, 2, 3, 5, 6, 7, 12, 21, 31, 49],
                "injection_flag": np.repeat(0, nInjected),
                "ra": injectedRa,
                "dec": injectedDec,
                "mag": injectedMag,
                "source_type": np.repeat("DeltaFunction", nInjected),
            }
        )

        schema = SourceTable.makeMinimalSchema()
        self.diaSources = SourceCatalog(schema)

        # DIA sources at random positions in the box (not tied to any
        # injected source).
        for offset in range(5):
            randomRa = boxSizeDeg * rng.random()
            randomDec = -30 + boxSizeDeg * rng.random()
            record = self.diaSources.addNew()
            record.setId(100 + offset)
            record.setCoord(SpherePoint(randomRa, randomDec, degrees))

        # DIA sources displaced from each of the first 8 injected sources by
        # a random sign and a random fraction of offsetFactor per coordinate.
        for offset, (ra, dec) in enumerate(self.injectedCat[['ra', 'dec']][:8]):
            raSign, decSign = rng.choice([-1, 1], size=2)
            raShift = raSign * rng.random() * offsetFactor
            decShift = decSign * rng.random() * offsetFactor
            record = self.diaSources.addNew()
            record.setId(200 + offset)
            record.setCoord(SpherePoint(ra + raShift, dec + decShift, degrees))

        # Empty difference image. Its WCS maps (RA, Dec) = (0, -30) deg to
        # pixel (0, 0). The CD matrix corresponds to roughly 0.187
        # arcsec/pixel (assuming the usual deg/pixel CD units), so the image
        # spans about 12.8 arcmin on a side.
        self.diffIm = ExposureF(Extent2I(4096, 4096))
        cdMatrix = np.array([[5.19513851e-05, 2.81124812e-07],
                             [3.25186974e-07, 5.19112119e-05]])
        self.diffIm.setWcs(
            makeSkyWcs(Point2D(0, 0), SpherePoint(0, -30, degrees), cdMatrix)
        )

        # Same injected catalog plus one source placed outside the image, at
        # negative RA/Dec offsets from the WCS reference point.
        self.injectedCatForTrimming = self.injectedCat.copy()
        outsideSource = {
            'injection_id': 50,
            'injection_flag': 0,
            'ra': 360. - 0.1/6.,
            'dec': -30 - 0.1/6.,
            'mag': 20.0,
            'source_type': 'DeltaFunction',
        }
        self.injectedCatForTrimming.add_row(outsideSource)

        # Associated DIA sources. Ids 101-103 are random (unmatched) sources.
        # Ids 201, 202, 205 and 207 sit next to injected sources, so exactly
        # 4 injected sources are associated.
        nAssoc = 7
        self.assocDiaSources = pd.DataFrame(
            {
                "diaSourceId": [101, 102, 103, 201, 202, 205, 207],
                "band": np.repeat("r", nAssoc),
                "visit": np.repeat(410001, nAssoc),
                "detector": np.repeat(0, nAssoc),
                "diaObjectId": np.arange(nAssoc),
            }
        )
class TestMatchInjectedToDiaSourceTask(BaseTestMatchInjected):
    """Tests for `MatchInjectedToDiaSourceTask`."""

    def _makeConfig(self, doMatchVisit):
        """Build a task config with a 0.5 arcsec match radius and forced
        measurement disabled.

        Parameters
        ----------
        doMatchVisit : `bool`
            Value assigned to ``config.doMatchVisit``.

        Returns
        -------
        config : `MatchInjectedToDiaSourceConfig`
            The populated config.
        """
        config = MatchInjectedToDiaSourceConfig()
        config.matchDistanceArcseconds = 0.5
        config.doMatchVisit = doMatchVisit
        config.doForcedMeasurement = False
        return config

    def _assertEightMatches(self, matched, config):
        """Check that exactly 8 rows of ``matched`` have a positive
        ``diaSourceId``, a positive ``dist_diaSrc``, and an absolute
        ``dist_diaSrc`` below ``config.matchDistanceArcseconds``.
        """
        self.assertEqual(np.sum(matched['diaSourceId'] > 0), 8)
        self.assertEqual(np.sum(matched['dist_diaSrc'] > 0), 8)
        self.assertEqual(
            np.sum(np.abs(matched['dist_diaSrc']) < config.matchDistanceArcseconds), 8
        )

    def test_run(self):
        config = self._makeConfig(doMatchVisit=False)
        task = MatchInjectedToDiaSourceTask(config=config)

        result = task.run(self.injectedCat, self.diffIm, self.diaSources)
        matched = result.matchDiaSources
        # One output row per injected source; 8 of them have a counterpart.
        self.assertEqual(len(matched), len(self.injectedCat))
        self._assertEightMatches(matched, config)

    def test_run_trimming(self):
        config = self._makeConfig(doMatchVisit=True)
        task = MatchInjectedToDiaSourceTask(config=config)

        result = task.run(self.injectedCatForTrimming, self.diffIm, self.diaSources)
        matched = result.matchDiaSources
        # With doMatchVisit enabled, the single source placed outside diffIm
        # is expected to be dropped.
        self.assertEqual(len(matched), len(self.injectedCatForTrimming) - 1)
        self._assertEightMatches(matched, config)

    def test_getVectors(self):
        task = MatchInjectedToDiaSourceTask(config=self._makeConfig(doMatchVisit=False))

        vectors = task._getVectors(
            np.radians(self.injectedCat['ra']),
            np.radians(self.injectedCat['dec']),
        )
        # One 3-component vector per injected source.
        self.assertEqual(vectors.shape, (10, 3))
class TestMatchInjectedToAssocDiaSourceTask(BaseTestMatchInjected):
    """Tests for `MatchInjectedToAssocDiaSourceTask`."""

    def test_run(self):
        # Step 1: match injected sources to DIA sources (0.5 arcsec radius,
        # no visit trimming, no forced measurement).
        matchConfig = MatchInjectedToDiaSourceConfig()
        matchConfig.matchDistanceArcseconds = 0.5
        matchConfig.doMatchVisit = False
        matchConfig.doForcedMeasurement = False
        matchTask = MatchInjectedToDiaSourceTask(config=matchConfig)
        matchResult = matchTask.run(self.injectedCat, self.diffIm, self.diaSources)

        # Step 2: cross-match those results against the associated DIA
        # sources, then count the rows marked ``isAssocDiaSource``.
        assocTask = MatchInjectedToAssocDiaSourceTask(
            config=MatchInjectedToAssocDiaSourceConfig()
        )
        assocResult = assocTask.run(self.assocDiaSources, matchResult.matchDiaSources)
        matchedAssoc = assocResult.matchAssocDiaSources

        self.assertEqual(len(matchedAssoc), len(self.injectedCat))
        # Of the ids in assocDiaSources, only 201, 202, 205 and 207 were
        # placed next to injected sources.
        self.assertEqual(np.sum(matchedAssoc['isAssocDiaSource']), 4)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()