Coverage for tests / test_diaPipe.py: 17%
257 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:09 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:09 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import contextlib
23import tempfile
24import unittest
25from unittest.mock import patch, Mock, MagicMock, DEFAULT
26import warnings
28import numpy as np
29import pandas as pd
30import astropy.table as tb
32import lsst.afw.image as afwImage
33import lsst.afw.table as afwTable
34import lsst.dax.apdb as daxApdb
35from lsst.meas.base import IdGenerator
36import lsst.pex.config as pexConfig
37import lsst.pipe.base as pipeBase
38import lsst.utils.tests
39from lsst.pipe.base.testUtils import assertValidOutput
41from lsst.ap.association import DiaPipelineTask
42from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema
43from utils_tests import makeExposure, makeDiaObjects, makeDiaSources, makeDiaForcedSources, \
44 makeSolarSystemSources
47def _makeMockDataFrame():
48 """Create a new mock of a DataFrame.
50 Returns
51 -------
52 mock : `unittest.mock.Mock`
53 A mock guaranteed to accept all operations used by `pandas.DataFrame`.
54 """
55 with warnings.catch_warnings():
56 # spec triggers deprecation warnings on DataFrame, but will
57 # automatically adapt to any removals.
58 warnings.simplefilter("ignore", category=DeprecationWarning)
59 return MagicMock(spec=pd.DataFrame())
def _makeMockTable():
    """Build a fresh stand-in object for an `astropy.table.Table`.

    Returns
    -------
    mock : `unittest.mock.MagicMock`
        A mock whose attribute set is restricted to that of an empty
        `astropy.table.Table` instance.
    """
    with warnings.catch_warnings():
        # Building the spec may emit DeprecationWarnings from Table;
        # silence them here so that the mock simply tracks whatever
        # attributes the installed astropy version provides.
        warnings.simplefilter("ignore", category=DeprecationWarning)
        tableMock = MagicMock(spec=tb.Table())
    return tableMock
class TestDiaPipelineTask(unittest.TestCase):
    """Unit tests for `lsst.ap.association.DiaPipelineTask`.

    Every test uses a throwaway sqlite-backed APDB whose configuration file
    is created in `setUp` and removed again on cleanup.
    """

    @classmethod
    def _makeDefaultConfig(cls, config_file, **kwargs):
        """Create a task configuration that points at an APDB config file.

        Parameters
        ----------
        config_file : `str`
            Value assigned to the ``apdb_config_url`` config field.
        **kwargs
            Additional config field values, forwarded to ``config.update``.

        Returns
        -------
        config : `DiaPipelineTask.ConfigClass`
            The populated configuration.
        """
        config = DiaPipelineTask.ConfigClass()
        config.apdb_config_url = config_file
        config.update(**kwargs)
        return config

    def setUp(self):
        """Create the test catalogs and an empty sqlite APDB."""
        # Create an instance of random generator with fixed seed.
        rng = np.random.default_rng(1234)
        self.rng = rng

        # schemas are persisted in both Gen 2 and Gen 3 butler as prototypical catalogs
        srcSchema = afwTable.SourceTable.makeMinimalSchema()
        srcSchema.addField("base_PixelFlags_flag", type="Flag")
        srcSchema.addField("base_PixelFlags_flag_offimage", type="Flag")
        self.srcSchema = afwTable.SourceCatalog(srcSchema)
        self.exposure = makeExposure(False, False)
        self.diaObjects = makeDiaObjects(20, self.exposure, rng)
        self.diaSources = makeDiaSources(
            100, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.diaForcedSources = makeDiaForcedSources(
            200, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.ssSources = makeSolarSystemSources(
            20, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)

        # Both temporary files are closed (and thereby deleted) on cleanup.
        sqlite_file = tempfile.NamedTemporaryFile()
        self.addCleanup(sqlite_file.close)
        self.config_file = tempfile.NamedTemporaryFile()
        self.addCleanup(self.config_file.close)
        apdb_config = daxApdb.ApdbSql.init_database(db_url=f"sqlite:///{sqlite_file.name}")
        apdb_config.save(self.config_file.name)

    def testRun(self):
        """Test running with both alert packaging and solar system
        association enabled.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=True, doReloadDiaObjects=False)

    def testRunWithSolarSystemAssociation(self):
        """Test running with solar system association but without packaging
        alerts.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True, doReloadDiaObjects=False)

    def testRunWithAlerts(self):
        """Test running while creating and packaging alerts, without solar
        system association.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=False)

    def testRunWithoutAlertsOrSolarSystem(self):
        """Test running without packaging alerts or associating solar system
        objects.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False, doReloadDiaObjects=False)

    def testRunWithReload(self):
        """Test running with reloading DiaObjects.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False, doReloadDiaObjects=True)

    def testRunWithReloadAndSolarSystem(self):
        """Test running with solar system association and reloading DiaObjects.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True, doReloadDiaObjects=True)

    def testRunWithReloadAndAlerts(self):
        """Test running with reloading DiaObjects while creating and packaging alerts.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=True)

    def _testRun(self, doPackageAlerts=False, doSolarSystemAssociation=False, doRunForcedMeasurement=False,
                 doReloadDiaObjects=False):
        """Test the normal workflow of each ap_pipe step.

        Parameters
        ----------
        doPackageAlerts : `bool`, optional
            Value of the task config field of the same name. When `True`
            the ``alertPackager`` subtask is mocked and must run once;
            otherwise the task must not have that subtask at all.
        doSolarSystemAssociation : `bool`, optional
            Value of the task config field of the same name. Determines
            whether the solar system associator is expected to run once or
            not at all.
        doRunForcedMeasurement : `bool`, optional
            Value of the task config field of the same name.
        doReloadDiaObjects : `bool`, optional
            Value of the task config field of the same name.
        """
        config = self._makeDefaultConfig(
            config_file=self.config_file.name,
            doPackageAlerts=doPackageAlerts,
            doSolarSystemAssociation=doSolarSystemAssociation,
            doRunForcedMeasurement=doRunForcedMeasurement,
            doReloadDiaObjects=doReloadDiaObjects,
        )
        task = DiaPipelineTask(config=config)
        # Set DataFrame index testing to always return False. Mocks return
        # true for this check otherwise.
        task.testDataFrameIndex = lambda x: False
        diffIm = Mock(spec=afwImage.ExposureF)
        template = Mock(spec=afwImage.ExposureF)
        diaSrc = _makeMockDataFrame()
        ssObjects = _makeMockTable()

        # Each of these subtasks should be called once during diaPipe
        # execution. We use mocks here to check they are being executed
        # appropriately.
        subtasksToMock = [
            "diaCalculation",
        ]
        if doPackageAlerts:
            subtasksToMock.append("alertPackager")
        else:
            self.assertFalse(hasattr(task, "alertPackager"))

        if not doSolarSystemAssociation:
            self.assertFalse(hasattr(task, "solarSystemAssociator"))

        def concatMock(_data, **_kwargs):
            return _makeMockDataFrame()

        # Mock out the run() methods of these Tasks to ensure they
        # return data in the correct form.
        def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, visitInfo,
                                      bbox, wcs):
            return pipeBase.Struct(nTotalSsObjects=42,
                                   nAssociatedSsObjects=30,
                                   ssoAssocDiaSources=_makeMockTable(),
                                   unAssocDiaSources=_makeMockTable(),
                                   associatedSsSources=_makeMockTable(),
                                   unassociatedSsObjects=_makeMockTable())

        def associator_run(table, diaObjects):
            return pipeBase.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
                                   matchedDiaSources=_makeMockDataFrame(),
                                   unAssocDiaSources=_makeMockDataFrame())

        def loadObjects_run(region, preloadedDiaObjects):
            # Write the two timing keys into the task metadata ourselves,
            # since the real method is replaced by this stand-in.
            task.metadata['loadRefreshedDiaObjectsStartUtc'] = 1.234
            task.metadata['loadRefreshedDiaObjectsEndUtc'] = 5.678
            return self.diaObjects

        def updateObjectTableMock(diaObjects, diaSources):
            pass

        # apdb isn't a subtask, but still needs to be mocked out for correct
        # execution in the test environment.
        with patch.multiple(task, **{name: DEFAULT for name in subtasksToMock + ["apdb"]}), \
                patch('lsst.ap.association.diaPipe.pd.concat', side_effect=concatMock), \
                patch('lsst.ap.association.diaPipe.DiaPipelineTask.updateObjectTable',
                      side_effect=updateObjectTableMock), \
                patch('lsst.ap.association.diaPipe.DiaPipelineTask.loadRefreshedDiaObjects',
                      side_effect=loadObjects_run), \
                patch('lsst.ap.association.association.AssociationTask.run',
                      side_effect=associator_run) as mainRun, \
                patch('lsst.pipe.tasks.ssoAssociation.SolarSystemAssociationTask.run',
                      side_effect=solarSystemAssociator_run) as ssRun:

            result = task.run(diaSrc,
                              None,
                              diffIm,
                              self.exposure,
                              template,
                              preloadedDiaObjects=self.diaObjects,
                              preloadedDiaSources=self.diaSources,
                              preloadedDiaForcedSources=self.diaForcedSources,
                              band="g",
                              idGenerator=IdGenerator(),
                              solarSystemObjectTable=ssObjects)
            # The subtask mocks only exist while the patches are active, so
            # they have to be inspected inside the ``with`` block.
            for subtaskName in subtasksToMock:
                getattr(task, subtaskName).run.assert_called_once()
            assertValidOutput(task, result)
            # Exact type and contents of apdbMarker are undefined.
            self.assertIsInstance(result.apdbMarker, pexConfig.Config)
            meta = task.getFullMetadata()
            # Check that the expected metadata has been set.
            self.assertEqual(meta["diaPipe.numUpdatedDiaObjects"], 2)
            self.assertEqual(meta["diaPipe.numUnassociatedDiaObjects"], 3)
            # and that associators ran once or not at all.
            mainRun.assert_called_once()
            if doSolarSystemAssociation:
                ssRun.assert_called_once()
            else:
                ssRun.assert_not_called()

    def test_tooManyDiaObjectsError(self):
        """Test that association raises `lsst.pipe.base.AlgorithmError` only
        when the number of diaSources exceeds a positive ``maxNewDiaObjects``.
        """
        maxNewDiaObjects = 100

        nDiaSources = maxNewDiaObjects + 1
        # All sources carry diaObjectId 0.
        diaSources = makeDiaSources(nDiaSources, np.zeros(nDiaSources), self.exposure, self.rng)

        def runAndTestWithContextManager(threshold):
            config = self._makeDefaultConfig(config_file=self.config_file.name,
                                             doSolarSystemAssociation=False,
                                             filterUnAssociatedSources=False,
                                             maxNewDiaObjects=threshold,
                                             )
            task = DiaPipelineTask(config=config)
            # A threshold of 0 is expected never to raise.
            contextManager = self.assertRaises(pipeBase.AlgorithmError) if nDiaSources > threshold > 0 \
                else contextlib.nullcontext()
            with contextManager:
                task.associateDiaSources(
                    diaSources,
                    None,
                    None,
                    self.diaObjects,
                )

        # Test cases at, above, and below the threshold as well as at 0.
        for threshold in (0, maxNewDiaObjects - 1, maxNewDiaObjects, maxNewDiaObjects + 1):
            with self.subTest(threshold=threshold):
                runAndTestWithContextManager(threshold)

    def test_createDiaObjects(self):
        """Test that creating new DiaObjects works as expected.
        """
        nSources = 5
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        diaSources = pd.DataFrame(data=[
            {"ra": 0.04*idx, "dec": 0.04*idx,
             "diaSourceId": idx + 1 + nSources, "diaObjectId": 0,
             "ssObjectId": 0}
            for idx in range(nSources)])

        result = task.createNewDiaObjects(convertDataFrameToSdmSchema(task.schema, diaSources, "DiaSource",
                                                                      skipIndex=True))
        self.assertEqual(nSources, len(result.newDiaObjects))
        # Each new object must take the ID of the source that created it.
        self.assertTrue(np.all(np.equal(
            result.diaSources["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())))
        self.assertTrue(np.all(np.equal(
            result.newDiaObjects["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())))

    def test_purgeDiaObjects(self):
        """Remove diaObjects that are outside an image's bounding box.
        """

        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        exposure = makeExposure(False, False)
        nObj0 = 20

        # Create diaObjects
        diaObjects = makeDiaObjects(nObj0, exposure, self.rng)
        # Shrink the bounding box so that some of the diaObjects will be outside
        bbox = exposure.getBBox()
        size = np.minimum(bbox.getHeight(), bbox.getWidth())
        bbox.grow(-size//4)
        exposureCut = exposure[bbox]
        sizeCut = np.minimum(bbox.getHeight(), bbox.getWidth())
        buffer = 10
        # From here on ``bbox`` is the cut-out box padded by ``buffer``.
        bbox.grow(buffer)

        def check_diaObjects(bbox, wcs, diaObjects):
            raVals = diaObjects.ra.to_numpy()
            decVals = diaObjects.dec.to_numpy()
            xVals, yVals = wcs.skyToPixelArray(raVals, decVals, degrees=True)
            selector = bbox.contains(xVals, yVals)
            return selector

        selector0 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects)
        nIn0 = np.count_nonzero(selector0)
        nOut0 = np.count_nonzero(~selector0)
        self.assertEqual(nObj0, nIn0 + nOut0)

        # Add an ID that is not in the diaObject table. It should not get removed.
        diaObjectIds0 = diaObjects["diaObjectId"].copy(deep=True)
        diaObjectIds0[max(diaObjectIds0.index) + 1] = 999
        diaObjects1, objIds = task.purgeDiaObjects(exposureCut.getBBox(), exposureCut.getWcs(), diaObjects,
                                                   diaObjectIds=diaObjectIds0, buffer=buffer)
        diaObjectIds1 = diaObjects1["diaObjectId"]
        # Verify that the bounding box was not changed
        sizeCheck = np.minimum(exposureCut.getBBox().getHeight(), exposureCut.getBBox().getWidth())
        self.assertEqual(sizeCut, sizeCheck)
        selector1 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects1)
        nIn1 = np.count_nonzero(selector1)
        nOut1 = np.count_nonzero(~selector1)
        nObj1 = len(diaObjects1)
        self.assertEqual(nObj1, nIn0)
        # Verify that not all diaObjects were removed
        self.assertGreater(nObj1, 0)
        # Check that some diaObjects were removed
        self.assertLess(nObj1, nObj0)
        # Verify that no objects outside the bounding box remain
        self.assertEqual(nOut1, 0)
        # Verify that no objects inside the bounding box were removed
        self.assertEqual(nIn1, nIn0)
        # The length of the updated object IDs should equal the number of objects
        # plus one, since we added an extra ID.
        self.assertEqual(nObj1 + 1, len(objIds))
        # All of the object IDs extracted from the catalog should be in the pruned object IDs
        self.assertTrue(set(objIds).issuperset(diaObjectIds1))
        # The pruned object IDs should contain entries that are not in the catalog
        self.assertFalse(set(diaObjectIds1).issuperset(objIds))
        # Some IDs should have been removed
        self.assertLess(len(objIds), len(diaObjectIds0))

    def test_filterDiaObjects(self):
        """Unassociated diaSources that are filtered should have good reliability and SNR.
        Glint trail sources should also be filtered out.
        """

        config = self._makeDefaultConfig(config_file=self.config_file.name,
                                         doPackageAlerts=False,
                                         filterUnAssociatedSources=True)

        configBadFilter = self._makeDefaultConfig(config_file=self.config_file.name,
                                                  doPackageAlerts=False,
                                                  filterUnAssociatedSources=True,
                                                  newObjectFluxField="notAFlux")
        configBadFlag = self._makeDefaultConfig(config_file=self.config_file.name,
                                                doPackageAlerts=False,
                                                filterUnAssociatedSources=True,
                                                newObjectBadFlags=("junkSource", "notUsed"))
        # Constructing the task with either invalid config must fail.
        with self.assertRaises(pipeBase.InvalidQuantumError):
            DiaPipelineTask(config=configBadFilter)
        with self.assertRaises(pipeBase.InvalidQuantumError):
            DiaPipelineTask(config=configBadFlag)
        task = DiaPipelineTask(config=config)
        nUnassociatedDiaSources = 234

        # Create diaSources
        diaSources = makeDiaSources(nUnassociatedDiaSources,
                                    np.zeros(nUnassociatedDiaSources),
                                    self.exposure,
                                    self.rng,
                                    flagList=task.config.newObjectBadFlags)
        reliability = self.rng.random(nUnassociatedDiaSources)
        flux = (self.rng.random(nUnassociatedDiaSources)**2)*100
        fluxErr = np.sqrt(flux)
        glint_trail = np.zeros(nUnassociatedDiaSources, dtype=bool)
        glint_trail[12:16] = True  # add 4 glint trail sources
        diaSources["reliability"] = reliability
        diaSources[config.newObjectFluxField] = flux
        diaSources[config.newObjectFluxField + "Err"] = fluxErr
        diaSources["glint_trail"] = glint_trail
        badFlagName = task.config.newObjectBadFlags[0]
        badFlags = np.zeros(nUnassociatedDiaSources, dtype=bool)
        nBadFlags = 20
        badFlags[0:nBadFlags] = True
        diaSources[badFlagName] = badFlags

        def runAndCheckFilter(diaSources, snrThreshold=None, lowReliabilitySnrThreshold=None,
                              reliabilityThreshold=None, lowSnrReliabilityThreshold=None,
                              badFlags=None,
                              ):
            # Filter a copy so that every call starts from the same catalog.
            filterResults = task.filterSources(
                diaSources.copy(deep=True),
                snrThreshold=snrThreshold,
                lowReliabilitySnrThreshold=lowReliabilitySnrThreshold,
                reliabilityThreshold=reliabilityThreshold,
                lowSnrReliabilityThreshold=lowSnrReliabilityThreshold,
                badFlags=badFlags,
            )
            # Every input source must end up in exactly one output catalog.
            self.assertEqual(len(filterResults.goodSources) + len(filterResults.badSources),
                             nUnassociatedDiaSources)
            goodFlux = filterResults.goodSources[config.newObjectFluxField]
            goodFluxErr = filterResults.goodSources[config.newObjectFluxField + "Err"]
            goodSnr = np.array(goodFlux/goodFluxErr)
            self.assertTrue(np.all(goodSnr > snrThreshold))
            goodReliability = np.array(filterResults.goodSources["reliability"])
            self.assertTrue(np.all(goodReliability > reliabilityThreshold))
            goodLowSnrFlag = goodSnr < lowReliabilitySnrThreshold
            lowSnrReliability = goodReliability[goodLowSnrFlag]
            self.assertTrue(np.all(lowSnrReliability > lowSnrReliabilityThreshold))
            glintTrailSources = np.array(filterResults.goodSources["glint_trail"])
            self.assertFalse(np.any(glintTrailSources))

        # No sources should be removed if the thresholds are turned off
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0)
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0,
                          badFlags=[badFlagName])
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5,
                          badFlags=[badFlagName])

    def test_mergeEmptyCatalog(self):
        """Test that a catalog is unchanged if it is merged with an empty
        catalog.
        """
        diaSourcesBase = self.diaSources

        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        # Include some but not all columns that should be in diaSourcesBase, and some that are mis-matched
        diaSourcesEmpty = pd.DataFrame(columns=["ra", "dec", "foo"])
        diaSourcesTest = task.mergeCatalogs(diaSourcesBase, diaSourcesEmpty, tableName="DiaSource")
        self.assertTrue(diaSourcesBase.equals(diaSourcesTest))

    def test_mergeCatalogs(self):
        """Test that a merged catalog is concatenated correctly.
        """
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)

        diaSourcesBase = convertDataFrameToSdmSchema(task.schema, self.diaSources, "DiaSource",
                                                     skipIndex=True)
        nBase = len(diaSourcesBase)
        nNew = int(nBase/2)

        diaSourcesNew = makeDiaSources(nNew, self.diaObjects["diaObjectId"].to_numpy(), self.exposure,
                                       self.rng)
        diaSourcesNew = convertDataFrameToSdmSchema(task.schema, diaSourcesNew, "DiaSource", skipIndex=True)
        diaSourcesTest = task.mergeCatalogs(diaSourcesBase, diaSourcesNew, tableName="DiaSource")
        self.assertEqual(len(diaSourcesTest), nBase + nNew)
        # The base rows must come first, followed by the new rows.
        diaSourcesExtract1 = diaSourcesTest.iloc[:nBase]
        diaSourcesExtract2 = diaSourcesTest.iloc[nBase:]

        pd.testing.assert_frame_equal(diaSourcesBase, diaSourcesExtract1)
        pd.testing.assert_frame_equal(diaSourcesNew, diaSourcesExtract2)

    def test_updateObjectTable(self):
        """Test that the diaObject record is updated with the number of
        diaSources.
        """
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        nObjects = 20
        nSrcPerObject = 10
        nExtraSources = 5
        nSources = nSrcPerObject*nObjects + nExtraSources
        # The first ``nExtraSources`` objects are expected to receive one
        # source more than the others.
        expectedSourcesPerObject = nSrcPerObject*np.ones(nObjects)
        expectedSourcesPerObject[:nExtraSources] += 1
        diaObjects = makeDiaObjects(nObjects, self.exposure, self.rng)
        diaSources = makeDiaSources(nSources, diaObjects["diaObjectId"].to_numpy(), self.exposure, self.rng)
        updatedDiaObjects = task.updateObjectTable(diaObjects, diaSources)
        self.assertTrue(np.all(updatedDiaObjects.nDiaSources.values == expectedSourcesPerObject))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Expose the checks inherited from `lsst.utils.tests.MemoryTestCase`
    to the test runner; this subclass adds no tests of its own.
    """
def setup_module(module):
    """Initialize the LSST test utilities before the tests in this module.

    Parameters
    ----------
    module : `types.ModuleType`
        Not used by this function; presumably supplied by the test runner
        when it invokes this hook.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Allow the test module to be executed directly as a script: initialize
    # the LSST test utilities, then hand over to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()