Coverage for tests/test_transformDiaSourceCatalog.py: 18%
131 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-06 03:44 -0700
1# This file is part of ap_association
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
25import numpy as np
27from lsst.ap.association.transformDiaSourceCatalog import (TransformDiaSourceCatalogConfig,
28 TransformDiaSourceCatalogTask)
29from lsst.afw.cameraGeom.testUtils import DetectorWrapper
30import lsst.daf.base as dafBase
31import lsst.afw.image as afwImage
32import lsst.geom as geom
33import lsst.meas.base.tests as measTests
34from lsst.pipe.base import Struct
35import lsst.utils.tests
37from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags
39TESTDIR = os.path.abspath(os.path.dirname(__file__))
class TestTransformDiaSourceCatalogTask(unittest.TestCase):
    """Tests for `TransformDiaSourceCatalogTask` and `UnpackApdbFlags`."""

    def setUp(self):
        """Build the catalog, exposure, reliability table and config shared
        by the tests.

        The input catalog holds ``self.nSources`` sources placed at
        ``(index, self.yLoc)``. The first source is flagged as a sky source,
        and the last one is created without a peak ``significance`` field.
        """
        # The first source will be a sky source.
        self.nSources = 10
        # Default PSF size (psfDim in makeEmptyExposure) in TestDataset results
        # in an 18 pixel wide source box.
        self.bboxSize = 18
        self.yLoc = 100
        self.bbox = geom.Box2I(geom.Point2I(0, 0),
                               geom.Extent2I(1024, 1153))
        dataset = measTests.TestDataset(self.bbox)
        for srcIdx in range(self.nSources - 1):
            # Place sources at (index, yLoc), so we can distinguish them later.
            dataset.addSource(100000.0, geom.Point2D(srcIdx, self.yLoc))
        # Ensure the last source has no peak `significance` field. Its x
        # position is written in terms of nSources (not the loop variable
        # left over from the loop above) so it stays at x == nSources-1.
        dataset.addSource(100000.0, geom.Point2D(self.nSources - 1, self.yLoc),
                          setPeakSignificance=False)
        schema = dataset.makeMinimalSchema()
        schema.addField("base_PixelFlags_flag", type="Flag")
        schema.addField("base_PixelFlags_flag_offimage", type="Flag")
        schema.addField("sky_source", type="Flag", doc="Sky objects.")
        self.exposure, self.inputCatalog = dataset.realize(10.0, schema, randomSeed=1234)
        self.inputCatalog[0]['sky_source'] = True
        # Create schemas for use in initializing the TransformDiaSourceCatalog task.
        self.initInputs = {"diaSourceSchema": Struct(schema=schema)}
        # Minimal schema without the flag fields added above; used to check
        # that missing flag columns are reported.
        self.initInputsBadFlags = {"diaSourceSchema": Struct(schema=dataset.makeMinimalSchema())}

        # Separate real/bogus score table, indexed on the above catalog ids.
        reliabilitySchema = lsst.afw.table.Schema()
        reliabilitySchema.addField(self.inputCatalog.schema["id"].asField())
        reliabilitySchema.addField("score", doc="real/bogus score of this source", type=float)
        self.reliability = lsst.afw.table.BaseCatalog(reliabilitySchema)
        self.reliability.resize(len(self.inputCatalog))
        self.reliability["id"] = self.inputCatalog["id"]
        # Seeded generator so the scores are reproducible between runs.
        rng = np.random.default_rng(1234)
        self.reliability["score"] = rng.random(len(self.inputCatalog))

        self.expId = 4321
        self.date = dafBase.DateTime(nsecs=1400000000 * 10**9)
        detector = DetectorWrapper(id=23, bbox=self.exposure.getBBox()).detector
        visit = afwImage.VisitInfo(
            exposureId=self.expId,
            exposureTime=200.,
            date=self.date)
        self.exposure.info.id = self.expId
        self.exposure.setDetector(detector)
        self.exposure.info.setVisitInfo(visit)
        self.band = 'g'
        self.exposure.setFilter(afwImage.FilterLabel(band=self.band, physical='g.MP9401'))
        scale = 2
        scaleErr = 1
        self.photoCalib = afwImage.PhotoCalib(scale, scaleErr)
        self.exposure.setPhotoCalib(self.photoCalib)

        self.config = TransformDiaSourceCatalogConfig()
        self.config.flagMap = os.path.join(TESTDIR, "data", "test-flag-map.yaml")
        self.config.functorFile = os.path.join(TESTDIR,
                                               "data",
                                               "testDiaSource.yaml")

    def test_run(self):
        """Test output dataFrame is created and values are correctly inserted
        from the exposure.
        """
        transformTask = TransformDiaSourceCatalogTask(initInputs=self.initInputs,
                                                      config=self.config)
        result = transformTask.run(self.inputCatalog,
                                   self.exposure,
                                   self.band,
                                   ccdVisitId=self.expId)

        self.assertEqual(len(result.diaSourceTable), len(self.inputCatalog))
        np.testing.assert_array_equal(result.diaSourceTable["bboxSize"], [self.bboxSize]*self.nSources)
        np.testing.assert_array_equal(result.diaSourceTable["ccdVisitId"], [self.expId]*self.nSources)
        np.testing.assert_array_equal(result.diaSourceTable["band"], [self.band]*self.nSources)
        np.testing.assert_array_equal(result.diaSourceTable["midpointMjdTai"],
                                      [self.date.get(system=dafBase.DateTime.MJD)]*self.nSources)
        np.testing.assert_array_equal(result.diaSourceTable["diaObjectId"], [0]*self.nSources)
        np.testing.assert_array_equal(result.diaSourceTable["x"], np.arange(self.nSources))
        # The final snr value should be NaN because it doesn't have a peak significance field.
        expect_snr = [397.887353515625]*(self.nSources - 1)
        expect_snr.append(np.nan)
        # Have to use allclose because assert_array_equal doesn't support equal_nan.
        np.testing.assert_allclose(result.diaSourceTable["snr"], expect_snr, equal_nan=True, rtol=0)

    def test_run_with_reliability(self):
        """Test that with doIncludeReliability=True the supplied reliability
        scores are copied into the output ``reliability`` column.
        """
        self.config.doIncludeReliability = True
        transformTask = TransformDiaSourceCatalogTask(initInputs=self.initInputs,
                                                      config=self.config)
        result = transformTask.run(self.inputCatalog,
                                   self.exposure,
                                   self.band,
                                   reliability=self.reliability,
                                   ccdVisitId=self.expId)
        self.assertEqual(len(result.diaSourceTable), len(self.inputCatalog))
        np.testing.assert_array_equal(result.diaSourceTable["reliability"], self.reliability["score"])

    def test_run_doSkySources(self):
        """Test that we get the correct output with doRemoveSkySources=True;
        the one sky source should be missing, but the other records should be
        the same.

        We only test the fields here that could be different, not the ones that
        are the same for all sources.
        """
        # Make the sky source have a different significance value, to distinguish it.
        self.inputCatalog[0].getFootprint().updatePeakSignificance(5.0)

        self.config.doRemoveSkySources = True
        task = TransformDiaSourceCatalogTask(initInputs=self.initInputs, config=self.config)
        result = task.run(self.inputCatalog, self.exposure, self.band, ccdVisitId=self.expId)

        self.assertEqual(len(result.diaSourceTable), self.nSources-1)
        # 0th source was removed, so x positions of the remaining sources are at x=1,2,3...
        np.testing.assert_array_equal(result.diaSourceTable["x"], np.arange(self.nSources-1)+1)
        # The final snr value should be NaN because it doesn't have a peak significance field.
        # One sky source removed and one NaN at the end leaves nSources-2 finite values.
        expect_snr = [397.887353515625]*(self.nSources - 2)
        expect_snr.append(np.nan)
        # Have to use allclose because assert_array_equal doesn't support equal_nan.
        np.testing.assert_allclose(result.diaSourceTable["snr"], expect_snr, equal_nan=True, rtol=0)

    def test_run_dia_source_wrong_flags(self):
        """Test that the proper errors are thrown when requesting flag columns
        that are not in the input schema.
        """
        with self.assertRaises(KeyError):
            TransformDiaSourceCatalogTask(initInputs=self.initInputsBadFlags)

    def test_computeBBoxSize(self):
        """Test that every source gets the expected footprint bbox size."""
        transform = TransformDiaSourceCatalogTask(initInputs=self.initInputs,
                                                  config=self.config)
        boxSizes = transform.computeBBoxSizes(self.inputCatalog)

        # Check the length first so a size mismatch gives a clear failure.
        self.assertEqual(len(boxSizes), self.nSources)
        for size in boxSizes:
            self.assertEqual(size, self.bboxSize)

    def test_bit_unpacker(self):
        """Test that flags packed into the integer ``flags`` column by the
        task are recovered correctly by `UnpackApdbFlags`.
        """
        transform = TransformDiaSourceCatalogTask(initInputs=self.initInputs,
                                                  config=self.config)
        for idx, obj in enumerate(self.inputCatalog):
            if idx in [1, 3, 5]:
                obj.set("base_PixelFlags_flag", 1)
            if idx in [1, 4, 6]:
                obj.set("base_PixelFlags_flag_offimage", 1)
        outputCatalog = transform.run(self.inputCatalog,
                                      self.exposure,
                                      self.band,
                                      ccdVisitId=self.expId).diaSourceTable

        unpacker = UnpackApdbFlags(self.config.flagMap, "DiaSource")
        flag_values = unpacker.unpack(outputCatalog["flags"], "flags")

        for idx, flag in enumerate(flag_values):
            if idx in [1, 3, 5]:
                self.assertTrue(flag['base_PixelFlags_flag'])
            else:
                self.assertFalse(flag['base_PixelFlags_flag'])

            if idx in [1, 4, 6]:
                self.assertTrue(flag['base_PixelFlags_flag_offimage'])
            else:
                self.assertFalse(flag['base_PixelFlags_flag_offimage'])

    def test_flag_existence_check(self):
        """Test `UnpackApdbFlags.flagExists` for known, empty, and bad-column
        lookups.
        """
        unpacker = UnpackApdbFlags(self.config.flagMap, "DiaSource")

        self.assertTrue(unpacker.flagExists('base_PixelFlags_flag'))
        self.assertFalse(unpacker.flagExists(''))
        with self.assertRaisesRegex(ValueError, 'column doesNotExist not in flag map'):
            unpacker.flagExists('base_PixelFlags_flag', columnName='doesNotExist')

    def test_flag_bitmask(self):
        """Test that we get the expected bitmask back from supplied flag names.
        """
        unpacker = UnpackApdbFlags(self.config.flagMap, "DiaSource")

        with self.assertRaisesRegex(ValueError, "flag '' not included"):
            unpacker.makeFlagBitMask([''])
        with self.assertRaisesRegex(ValueError, 'column doesNotExist not in flag map'):
            unpacker.makeFlagBitMask(['base_PixelFlags_flag'], columnName='doesNotExist')
        self.assertEqual(unpacker.makeFlagBitMask(['base_PixelFlags_flag']), np.uint64(1))
        self.assertEqual(unpacker.makeFlagBitMask(['base_PixelFlags_flag_offimage']), np.uint64(2))
        self.assertEqual(unpacker.makeFlagBitMask(['base_PixelFlags_flag',
                                                   'base_PixelFlags_flag_offimage']),
                         np.uint64(3))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull the standard `lsst.utils.tests.MemoryTestCase` checks into this
    module; no additional tests are defined here.
    """
def setup_module(module):
    """Initialize `lsst.utils.tests` before the tests in this module run.

    ``module`` is the module object passed by pytest's xunit-style
    module setup hook; it is not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities first, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()