Coverage for tests/test_diaPipe.py: 20%
157 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 09:33 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import tempfile
23import unittest
24from unittest.mock import patch, Mock, MagicMock, DEFAULT
25import warnings
27import numpy as np
28import pandas as pd
30import lsst.afw.image as afwImage
31import lsst.afw.table as afwTable
32import lsst.dax.apdb as daxApdb
33import lsst.pex.config as pexConfig
34import lsst.utils.tests
35from lsst.pipe.base.testUtils import assertValidOutput
37from lsst.ap.association import DiaPipelineTask
38from utils_tests import makeExposure, makeDiaObjects
41def _makeMockDataFrame():
42 """Create a new mock of a DataFrame.
44 Returns
45 -------
46 mock : `unittest.mock.Mock`
47 A mock guaranteed to accept all operations used by `pandas.DataFrame`.
48 """
49 with warnings.catch_warnings():
50 # spec triggers deprecation warnings on DataFrame, but will
51 # automatically adapt to any removals.
52 warnings.simplefilter("ignore", category=DeprecationWarning)
53 return MagicMock(spec=pd.DataFrame())
class TestDiaPipelineTask(unittest.TestCase):
    """Tests for `DiaPipelineTask` configuration, execution, and helpers."""

    @classmethod
    def _makeDefaultConfig(cls,
                           config_file,
                           doPackageAlerts=False,
                           doSolarSystemAssociation=False):
        """Build a `DiaPipelineTask` config that reads the APDB config from a file.

        Parameters
        ----------
        config_file : `str`
            Path or URL of a saved APDB configuration.
        doPackageAlerts : `bool`, optional
            Value assigned to ``config.doPackageAlerts``.
        doSolarSystemAssociation : `bool`, optional
            Value assigned to ``config.doSolarSystemAssociation``.

        Returns
        -------
        config : `DiaPipelineTask.ConfigClass`
            The new config (not frozen).
        """
        config = DiaPipelineTask.ConfigClass()
        config.doConfigureApdb = False
        config.apdb_config_url = config_file
        config.doPackageAlerts = doPackageAlerts
        config.doSolarSystemAssociation = doSolarSystemAssociation
        return config

    def setUp(self):
        # Create an instance of random generator with fixed seed so that
        # generated test data is reproducible.
        rng = np.random.default_rng(1234)
        self.rng = rng

        # schemas are persisted in both Gen 2 and Gen 3 butler as prototypical catalogs
        srcSchema = afwTable.SourceTable.makeMinimalSchema()
        srcSchema.addField("base_PixelFlags_flag", type="Flag")
        srcSchema.addField("base_PixelFlags_flag_offimage", type="Flag")
        self.srcSchema = afwTable.SourceCatalog(srcSchema)

        # Initialize an SQLite APDB ("sqlite://") and save its config to a
        # temporary file that the task configs point at. Closing the
        # NamedTemporaryFile at cleanup also deletes it.
        apdb_config = daxApdb.ApdbSql.init_database(db_url="sqlite://")
        self.config_file = tempfile.NamedTemporaryFile()
        self.addCleanup(self.config_file.close)
        apdb_config.save(self.config_file.name)

    # TODO: remove on DM-43419
    def testConfigApdbNestedOk(self):
        """A nested APDB config with a db_url validates (with a FutureWarning)."""
        config = DiaPipelineTask.ConfigClass()
        config.doConfigureApdb = True
        with self.assertWarns(FutureWarning):
            config.apdb.db_url = "sqlite://"
        config.freeze()
        config.validate()

    # TODO: remove on DM-43419
    def testConfigApdbNestedInvalid(self):
        """A nested APDB config without a db_url fails validation."""
        config = DiaPipelineTask.ConfigClass()
        config.doConfigureApdb = True
        # Don't set db_url
        config.freeze()
        with self.assertRaises(pexConfig.FieldValidationError):
            config.validate()

    # TODO: remove on DM-43419
    def testConfigApdbFileOk(self):
        """A file-based APDB config with apdb_config_url set validates."""
        config = DiaPipelineTask.ConfigClass()
        config.doConfigureApdb = False
        config.apdb_config_url = "some/file/path.yaml"
        config.freeze()
        config.validate()

    # TODO: remove on DM-43419
    def testConfigApdbFileInvalid(self):
        """A file-based APDB config without apdb_config_url fails validation."""
        config = DiaPipelineTask.ConfigClass()
        config.doConfigureApdb = False
        # Don't set apdb_config_url
        config.freeze()
        with self.assertRaises(pexConfig.FieldValidationError):
            config.validate()

    def testRun(self):
        """Test running with alert packaging and solar system association.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=True)

    def testRunWithSolarSystemAssociation(self):
        """Test running with solar system association but without alerts.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True)

    def testRunWithAlerts(self):
        """Test running while creating and packaging alerts, without solar
        system association.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False)

    def testRunWithoutAlertsOrSolarSystem(self):
        """Test running without alert packaging or solar system association.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False)

    def _testRun(self, doPackageAlerts=False, doSolarSystemAssociation=False):
        """Test the normal workflow of each ap_pipe step.

        Subtasks, the APDB, ``pd.concat`` and both associators are mocked, so
        this checks control flow: which steps run, how often, and the
        metadata and outputs that ``run`` produces.

        Parameters
        ----------
        doPackageAlerts : `bool`, optional
            Whether the task is configured to package alerts.
        doSolarSystemAssociation : `bool`, optional
            Whether the task is configured to run solar system association.
        """
        config = self._makeDefaultConfig(
            config_file=self.config_file.name,
            doPackageAlerts=doPackageAlerts,
            doSolarSystemAssociation=doSolarSystemAssociation)
        task = DiaPipelineTask(config=config)
        # Set DataFrame index testing to always return False. Mocks return
        # true for this check otherwise.
        task.testDataFrameIndex = lambda x: False
        diffIm = Mock(spec=afwImage.ExposureF)
        exposure = Mock(spec=afwImage.ExposureF)
        template = Mock(spec=afwImage.ExposureF)
        diaSrc = _makeMockDataFrame()
        ssObjects = _makeMockDataFrame()
        ccdExposureIdBits = 32

        # Each of these subtasks should be called once during diaPipe
        # execution. We use mocks here to check they are being executed
        # appropriately.
        subtasksToMock = [
            "diaCatalogLoader",
            "diaCalculation",
            "diaForcedSource",
        ]
        if doPackageAlerts:
            subtasksToMock.append("alertPackager")
        else:
            self.assertFalse(hasattr(task, "alertPackager"))

        if not doSolarSystemAssociation:
            self.assertFalse(hasattr(task, "solarSystemAssociator"))

        def concatMock(_data, **_kwargs):
            return _makeMockDataFrame()

        # Mock out the run() methods of these two Tasks to ensure they
        # return data in the correct form.
        def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, diffIm):
            return lsst.pipe.base.Struct(nTotalSsObjects=42,
                                         nAssociatedSsObjects=30,
                                         ssoAssocDiaSources=_makeMockDataFrame(),
                                         unAssocDiaSources=_makeMockDataFrame())

        def associator_run(table, diaObjects):
            return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
                                         matchedDiaSources=_makeMockDataFrame(),
                                         unAssocDiaSources=_makeMockDataFrame())

        # apdb isn't a subtask, but still needs to be mocked out for correct
        # execution in the test environment. The comprehension variable is
        # named distinctly so it does not visually shadow ``task`` above.
        attributesToMock = {name: DEFAULT for name in subtasksToMock + ["apdb"]}
        with patch.multiple(task, **attributesToMock), \
             patch('lsst.ap.association.diaPipe.pd.concat', side_effect=concatMock), \
             patch('lsst.ap.association.association.AssociationTask.run',
                   side_effect=associator_run) as mainRun, \
             patch('lsst.ap.association.ssoAssociation.SolarSystemAssociationTask.run',
                   side_effect=solarSystemAssociator_run) as ssRun:

            result = task.run(diaSrc,
                              ssObjects,
                              diffIm,
                              exposure,
                              template,
                              ccdExposureIdBits,
                              "g")
            for subtaskName in subtasksToMock:
                getattr(task, subtaskName).run.assert_called_once()
            assertValidOutput(task, result)
            # Exact type and contents of apdbMarker are undefined.
            self.assertIsInstance(result.apdbMarker, pexConfig.Config)
            meta = task.getFullMetadata()
            # Check that the expected metadata has been set.
            self.assertEqual(meta["diaPipe.numUpdatedDiaObjects"], 2)
            self.assertEqual(meta["diaPipe.numUnassociatedDiaObjects"], 3)
            # and that associators ran once or not at all.
            mainRun.assert_called_once()
            if doSolarSystemAssociation:
                ssRun.assert_called_once()
            else:
                ssRun.assert_not_called()

    def test_createDiaObjects(self):
        """Test that creating new DiaObjects works as expected.

        Each input diaSource (all with ``diaObjectId`` 0) should yield one new
        diaObject whose ``diaObjectId`` equals the source's ``diaSourceId``,
        and the returned sources should carry that same ``diaObjectId``.
        """
        nSources = 5
        diaSources = pd.DataFrame(data=[
            {"ra": 0.04*idx, "dec": 0.04*idx,
             "diaSourceId": idx + 1 + nSources, "diaObjectId": 0,
             "ssObjectId": 0}
            for idx in range(nSources)])

        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        result = task.createNewDiaObjects(diaSources)
        self.assertEqual(nSources, len(result.newDiaObjects))
        # assert_array_equal also checks shapes and reports mismatching
        # elements on failure, unlike assertTrue(np.all(...)).
        np.testing.assert_array_equal(
            result.diaSources["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())
        np.testing.assert_array_equal(
            result.newDiaObjects["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())

    def test_purgeDiaObjects(self):
        """Remove diaObjects that are outside an image's bounding box.

        Objects outside the bounding box grown by ``buffer`` pixels must be
        removed, objects inside it must be kept, and the bounding box passed
        to ``purgeDiaObjects`` must not be modified.
        """
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        exposure = makeExposure(False, False)
        nObj0 = 20

        # Create diaObjects
        diaObjects = makeDiaObjects(nObj0, exposure, self.rng)
        # Shrink the bounding box so that some of the diaObjects will be outside
        bbox = exposure.getBBox()
        size = np.minimum(bbox.getHeight(), bbox.getWidth())
        bbox.grow(-size//4)
        exposureCut = exposure[bbox]
        sizeCut = np.minimum(bbox.getHeight(), bbox.getWidth())
        buffer = 10
        # Expected retention region: the cut box plus the purge buffer.
        bbox.grow(buffer)

        def check_diaObjects(bbox, wcs, diaObjects):
            """Return a boolean mask of diaObjects whose pixel position is in bbox."""
            raVals = diaObjects.ra.to_numpy()
            decVals = diaObjects.dec.to_numpy()
            xVals, yVals = wcs.skyToPixelArray(raVals, decVals, degrees=True)
            selector = bbox.contains(xVals, yVals)
            return selector

        selector0 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects)
        nIn0 = np.count_nonzero(selector0)
        nOut0 = np.count_nonzero(~selector0)
        self.assertEqual(nObj0, nIn0 + nOut0)

        diaObjects1 = task.purgeDiaObjects(exposureCut.getBBox(), exposureCut.getWcs(), diaObjects,
                                           buffer=buffer)
        # Verify that the bounding box was not changed
        sizeCheck = np.minimum(exposureCut.getBBox().getHeight(), exposureCut.getBBox().getWidth())
        self.assertEqual(sizeCut, sizeCheck)
        selector1 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects1)
        nIn1 = np.count_nonzero(selector1)
        nOut1 = np.count_nonzero(~selector1)
        nObj1 = len(diaObjects1)
        self.assertEqual(nObj1, nIn0)
        # Verify that not all diaObjects were removed
        self.assertGreater(nObj1, 0)
        # Check that some diaObjects were removed
        self.assertLess(nObj1, nObj0)
        # Verify that no objects outside the bounding box remain
        self.assertEqual(nOut1, 0)
        # Verify that no objects inside the bounding box were removed
        self.assertEqual(nIn1, nIn0)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Reuse the checks from `lsst.utils.tests.MemoryTestCase` for this module.

    No additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook: initialize LSST test utilities.

    Parameters
    ----------
    module : `module`
        The test module being set up; not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running this file directly: initialize LSST test utilities, then
    # hand off to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()