Coverage for tests / test_diaPipe.py: 13%
350 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:05 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:05 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import contextlib
23import tempfile
24import unittest
25from unittest.mock import patch, MagicMock, DEFAULT
26import warnings
28import numpy as np
29import pandas as pd
30import astropy.table as tb
32import lsst.afw.table as afwTable
33import lsst.dax.apdb as daxApdb
34from lsst.meas.base import IdGenerator
35import lsst.pex.config as pexConfig
36import lsst.pipe.base as pipeBase
37import lsst.utils.tests
38from lsst.pipe.base.testUtils import assertValidOutput
40from lsst.ap.association import DiaPipelineTask
41from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema
42from utils_tests import makeExposure, makeDiaObjects, makeDiaSources, makeDiaForcedSources, \
43 makeSolarSystemSources
46def _makeMockDataFrame():
47 """Create a new mock of a DataFrame.
49 Returns
50 -------
51 mock : `unittest.mock.Mock`
52 A mock guaranteed to accept all operations used by `pandas.DataFrame`.
53 """
54 with warnings.catch_warnings():
55 # spec triggers deprecation warnings on DataFrame, but will
56 # automatically adapt to any removals.
57 warnings.simplefilter("ignore", category=DeprecationWarning)
58 return MagicMock(spec=pd.DataFrame())
def _makeMockTable():
    """Build a fresh stand-in for an `astropy.table.Table`.

    Returns
    -------
    mock : `unittest.mock.Mock`
        A ``MagicMock`` whose spec is an empty `astropy.table.Table`, so it
        accepts the operations a real Table exposes.
    """
    with warnings.catch_warnings():
        # Introspecting a Table for the spec may touch deprecated
        # attributes; silence those warnings rather than hard-coding the
        # attribute list, so the mock tracks any future removals.
        warnings.simplefilter("ignore", category=DeprecationWarning)
        emptyTable = tb.Table()
        mockTable = MagicMock(spec=emptyTable)
    return mockTable
class TestDiaPipelineTask(unittest.TestCase):
    """Unit tests for `lsst.ap.association.DiaPipelineTask`."""

    @classmethod
    def _makeDefaultConfig(cls, config_file, **kwargs):
        """Create a DiaPipelineTask config that points at an APDB config file.

        Parameters
        ----------
        config_file : `str`
            Value assigned to ``apdb_config_url``.
        **kwargs
            Additional config overrides, applied with ``config.update``.

        Returns
        -------
        config : `lsst.ap.association.DiaPipelineTask.ConfigClass`
            The populated config.
        """
        config = DiaPipelineTask.ConfigClass()
        config.apdb_config_url = config_file
        config.update(**kwargs)
        return config

    def setUp(self):
        """Create the exposures, catalogs, and temporary sqlite APDB shared
        by the tests.
        """
        # Create an instance of random generator with fixed seed.
        rng = np.random.default_rng(1234)
        self.rng = rng

        # schemas are persisted in both Gen 2 and Gen 3 butler as prototypical catalogs
        srcSchema = afwTable.SourceTable.makeMinimalSchema()
        srcSchema.addField("base_PixelFlags_flag", type="Flag")
        srcSchema.addField("base_PixelFlags_flag_offimage", type="Flag")
        self.srcSchema = afwTable.SourceCatalog(srcSchema)
        self.exposure = makeExposure(False, False)
        self.diffim = makeExposure(False, False)
        self.template = makeExposure(False, False)
        self.diaObjects = makeDiaObjects(20, self.exposure, rng)
        self.diaSources = makeDiaSources(
            100, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.diaForcedSources = makeDiaForcedSources(
            200, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.ssSources = makeSolarSystemSources(
            20, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)

        # Both temporary files are closed (and thereby removed) on cleanup.
        sqlite_file = tempfile.NamedTemporaryFile()
        self.addCleanup(sqlite_file.close)
        self.config_file = tempfile.NamedTemporaryFile()
        self.addCleanup(self.config_file.close)
        apdb_config = daxApdb.ApdbSql.init_database(db_url=f"sqlite:///{sqlite_file.name}")
        apdb_config.save(self.config_file.name)

    def testRun(self):
        """Test running with both alert packaging and solar system
        association enabled.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=True, doReloadDiaObjects=False)

    def testRunWithSolarSystemAssociation(self):
        """Test running with solar system association but without packaging
        alerts.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True, doReloadDiaObjects=False)

    def testRunWithAlerts(self):
        """Test running while creating and packaging alerts, without solar
        system association.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=False)

    def testRunWithoutAlertsOrSolarSystem(self):
        """Test running without creating and packaging alerts.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False, doReloadDiaObjects=False)

    def testRunWithReload(self):
        """Test running with reloading DiaObjects.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False, doReloadDiaObjects=True)

    def testRunWithReloadAndSolarSystem(self):
        """Test running with solar system association and reloading DiaObjects.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True, doReloadDiaObjects=True)

    def testRunWithReloadAndAlerts(self):
        """Test running with reloading DiaObjects while creating and packaging alerts.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=True)

    def testRunDisableDeprecatedDoRunForcedMeasurement(self):
        """Test running with forced sources disabled.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=True,
                      doRunForcedMeasurement=False, subtasksToMock=["diaCalculation", ]
                      )

    def _testRun(self, doPackageAlerts=False, doSolarSystemAssociation=False,
                 doReloadDiaObjects=False, subtasksToMock=None, **kwargs):
        """Test the normal workflow of each ap_pipe step.

        Parameters
        ----------
        doPackageAlerts : `bool`, optional
            Config override; when set, ``alertPackager`` is added to the
            mocked subtasks.
        doSolarSystemAssociation : `bool`, optional
            Config override; controls whether the solar system associator is
            expected to run.
        doReloadDiaObjects : `bool`, optional
            Config override passed through to the task config.
        subtasksToMock : `list` [`str`], optional
            Names of subtasks to replace with mocks. Defaults to
            ``["diaCalculation", "diaForcedSource"]``. The list passed in is
            not modified.
        **kwargs
            Additional config overrides.
        """
        config = self._makeDefaultConfig(
            config_file=self.config_file.name,
            doPackageAlerts=doPackageAlerts,
            doSolarSystemAssociation=doSolarSystemAssociation,
            doReloadDiaObjects=doReloadDiaObjects,
            **kwargs
        )
        task = DiaPipelineTask(config=config)
        # Set DataFrame index testing to always return False. Mocks return
        # true for this check otherwise.
        task.testDataFrameIndex = lambda x: False
        diaSrc = _makeMockDataFrame()
        ssObjects = _makeMockTable()

        # Each of these subtasks should be called once during diaPipe
        # execution. We use mocks here to check they are being executed
        # appropriately.
        if subtasksToMock is None:
            subtasksToMock = ["diaCalculation", "diaForcedSource"]
        else:
            # Work on a copy so the caller's list is never modified.
            subtasksToMock = list(subtasksToMock)
        if doPackageAlerts:
            subtasksToMock.append("alertPackager")
        else:
            self.assertFalse(hasattr(task, "alertPackager"))

        if not doSolarSystemAssociation:
            self.assertFalse(hasattr(task, "solarSystemAssociator"))

        def concatMock(_data, **_kwargs):
            return _makeMockDataFrame()

        # Mock out the run() methods of these Tasks to ensure they
        # return data in the correct form.
        def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, visitInfo,
                                      bbox, wcs):
            return pipeBase.Struct(nTotalSsObjects=42,
                                   nAssociatedSsObjects=30,
                                   ssoAssocDiaSources=_makeMockTable(),
                                   unAssocDiaSources=_makeMockTable(),
                                   associatedSsSources=_makeMockTable(),
                                   unassociatedSsObjects=_makeMockTable())

        def associator_run(table, diaObjects):
            return pipeBase.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
                                   matchedDiaSources=_makeMockDataFrame(),
                                   unAssocDiaSources=_makeMockDataFrame())

        def loadObjects_run(region, preloadedDiaObjects):
            task.metadata['loadRefreshedDiaObjectsStartUtc'] = 1.234
            task.metadata['loadRefreshedDiaObjectsEndUtc'] = 5.678
            return self.diaObjects

        def updateObjectTableMock(diaObjects, diaSources):
            pass

        def _selectGoodDiaObjects(diaObjectCat, mergedDiaSourceHistory):
            return diaObjectCat.copy(deep=True)

        # apdb isn't a subtask, but still needs to be mocked out for correct
        # execution in the test environment.
        attributesToMock = {name: DEFAULT for name in subtasksToMock + ["apdb"]}
        with patch.multiple(task, **attributesToMock), \
                patch('lsst.ap.association.diaPipe.pd.concat', side_effect=concatMock), \
                patch('lsst.ap.association.diaPipe.DiaPipelineTask.updateObjectTable',
                      side_effect=updateObjectTableMock), \
                patch('lsst.ap.association.diaPipe.DiaPipelineTask._selectGoodDiaObjects',
                      side_effect=_selectGoodDiaObjects), \
                patch('lsst.ap.association.diaPipe.DiaPipelineTask.loadRefreshedDiaObjects',
                      side_effect=loadObjects_run), \
                patch('lsst.ap.association.association.AssociationTask.run',
                      side_effect=associator_run) as mainRun, \
                patch('lsst.pipe.tasks.ssoAssociation.SolarSystemAssociationTask.run',
                      side_effect=solarSystemAssociator_run) as ssRun:

            result = task.run(diaSrc,
                              None,
                              self.diffim,
                              self.exposure,
                              self.template,
                              preloadedDiaObjects=self.diaObjects,
                              preloadedDiaSources=self.diaSources,
                              preloadedDiaForcedSources=self.diaForcedSources,
                              band="g",
                              idGenerator=IdGenerator(),
                              solarSystemObjectTable=ssObjects)
            # The subtask mocks only exist while patch.multiple is active, so
            # they must be inspected inside the ``with`` block.
            for subtaskName in subtasksToMock:
                getattr(task, subtaskName).run.assert_called_once()
            assertValidOutput(task, result)
            # Exact type and contents of apdbMarker are undefined.
            self.assertIsInstance(result.apdbMarker, pexConfig.Config)
            meta = task.getFullMetadata()
            # Check that the expected metadata has been set.
            self.assertEqual(meta["diaPipe.numUpdatedDiaObjects"], 2)
            self.assertEqual(meta["diaPipe.numUnassociatedDiaObjects"], 3)
            # and that associators ran once or not at all.
            mainRun.assert_called_once()
            if doSolarSystemAssociation:
                ssRun.assert_called_once()
            else:
                ssRun.assert_not_called()

    def test_tooManyDiaObjectsError(self):
        """Test that associateDiaSources raises `AlgorithmError` only when the
        number of diaSources exceeds a nonzero ``maxNewDiaObjects``.
        """
        maxNewDiaObjects = 100

        nDiaSources = maxNewDiaObjects + 1
        diaSources = makeDiaSources(nDiaSources, np.zeros(nDiaSources), self.exposure, self.rng)

        def runAndTestWithContextManager(threshold):
            config = self._makeDefaultConfig(config_file=self.config_file.name,
                                             doSolarSystemAssociation=False,
                                             filterUnAssociatedSources=False,
                                             maxNewDiaObjects=threshold,
                                             )
            task = DiaPipelineTask(config=config)
            # An error is expected only for a positive threshold that is
            # smaller than the number of sources.
            if nDiaSources > threshold > 0:
                contextManager = self.assertRaises(pipeBase.AlgorithmError)
            else:
                contextManager = contextlib.nullcontext()
            with contextManager:
                task.associateDiaSources(
                    diaSources,
                    None,
                    None,
                    self.diaObjects,
                )

        # Test cases at, above, and below the threshold as well as at 0.
        runAndTestWithContextManager(0)
        runAndTestWithContextManager(maxNewDiaObjects - 1)
        runAndTestWithContextManager(maxNewDiaObjects)
        runAndTestWithContextManager(maxNewDiaObjects + 1)

    def test_createDiaObjects(self):
        """Test that creating new DiaObjects works as expected.
        """
        nSources = 5
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        diaSources = pd.DataFrame(data=[
            {"ra": 0.04*idx, "dec": 0.04*idx,
             "diaSourceId": idx + 1 + nSources, "diaObjectId": 0,
             "ssObjectId": 0}
            for idx in range(nSources)])

        result = task.createNewDiaObjects(convertDataFrameToSdmSchema(task.schema, diaSources, "DiaSource",
                                                                      skipIndex=True))
        self.assertEqual(nSources, len(result.newDiaObjects))
        # Each new object takes the ID of the source that created it.
        self.assertTrue(np.all(np.equal(
            result.diaSources["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())))
        self.assertTrue(np.all(np.equal(
            result.newDiaObjects["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())))

    def test_purgeDiaObjects(self):
        """Remove diaObjects that are outside an image's bounding box.
        """

        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        exposure = makeExposure(False, False)
        nObj0 = 20

        # Create diaObjects
        diaObjects = makeDiaObjects(nObj0, exposure, self.rng)
        # Shrink the bounding box so that some of the diaObjects will be outside
        bbox = exposure.getBBox()
        size = np.minimum(bbox.getHeight(), bbox.getWidth())
        bbox.grow(-size//4)
        exposureCut = exposure[bbox]
        sizeCut = np.minimum(bbox.getHeight(), bbox.getWidth())
        buffer = 10
        bbox.grow(buffer)

        def check_diaObjects(bbox, wcs, diaObjects):
            raVals = diaObjects.ra.to_numpy()
            decVals = diaObjects.dec.to_numpy()
            xVals, yVals = wcs.skyToPixelArray(raVals, decVals, degrees=True)
            selector = bbox.contains(xVals, yVals)
            return selector

        selector0 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects)
        nIn0 = np.count_nonzero(selector0)
        nOut0 = np.count_nonzero(~selector0)
        self.assertEqual(nObj0, nIn0 + nOut0)

        # Add an ID that is not in the diaObject table. It should not get removed.
        diaObjectIds0 = diaObjects["diaObjectId"].copy(deep=True)
        diaObjectIds0[max(diaObjectIds0.index) + 1] = 999
        diaObjects1, objIds = task.purgeDiaObjects(exposureCut.getBBox(), exposureCut.getWcs(), diaObjects,
                                                   diaObjectIds=diaObjectIds0, buffer=buffer)
        diaObjectIds1 = diaObjects1["diaObjectId"]
        # Verify that the bounding box was not changed
        sizeCheck = np.minimum(exposureCut.getBBox().getHeight(), exposureCut.getBBox().getWidth())
        self.assertEqual(sizeCut, sizeCheck)
        selector1 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects1)
        nIn1 = np.count_nonzero(selector1)
        nOut1 = np.count_nonzero(~selector1)
        nObj1 = len(diaObjects1)
        self.assertEqual(nObj1, nIn0)
        # Verify that not all diaObjects were removed
        self.assertGreater(nObj1, 0)
        # Check that some diaObjects were removed
        self.assertLess(nObj1, nObj0)
        # Verify that no objects outside the bounding box remain
        self.assertEqual(nOut1, 0)
        # Verify that no objects inside the bounding box were removed
        self.assertEqual(nIn1, nIn0)
        # The length of the updated object IDs should equal the number of objects
        # plus one, since we added an extra ID.
        self.assertEqual(nObj1 + 1, len(objIds))
        # All of the object IDs extracted from the catalog should be in the pruned object IDs
        self.assertTrue(set(objIds).issuperset(diaObjectIds1))
        # The pruned object IDs should contain entries that are not in the catalog
        self.assertFalse(set(diaObjectIds1).issuperset(objIds))
        # Some IDs should have been removed
        self.assertLess(len(objIds), len(diaObjectIds0))

    def test_filterDiaObjects(self):
        """Unassociated diaSources that are filtered should have good reliability and SNR.
        Glint trail sources should also be filtered out.
        """

        config = self._makeDefaultConfig(config_file=self.config_file.name,
                                         doPackageAlerts=False,
                                         filterUnAssociatedSources=True)

        configBadFilter = self._makeDefaultConfig(config_file=self.config_file.name,
                                                  doPackageAlerts=False,
                                                  filterUnAssociatedSources=True,
                                                  newObjectFluxField="notAFlux")
        configBadFlag = self._makeDefaultConfig(config_file=self.config_file.name,
                                                doPackageAlerts=False,
                                                filterUnAssociatedSources=True,
                                                newObjectBadFlags=("junkSource", "notUsed"))
        # Constructing the task with an unknown flux field or flag must fail.
        with self.assertRaises(pipeBase.InvalidQuantumError):
            DiaPipelineTask(config=configBadFilter)
        with self.assertRaises(pipeBase.InvalidQuantumError):
            DiaPipelineTask(config=configBadFlag)
        task = DiaPipelineTask(config=config)
        nUnassociatedDiaSources = 234

        # Create diaSources
        diaSources = makeDiaSources(nUnassociatedDiaSources,
                                    np.zeros(nUnassociatedDiaSources),
                                    self.exposure,
                                    self.rng,
                                    flagList=task.config.newObjectBadFlags)
        reliability = self.rng.random(nUnassociatedDiaSources)
        flux = (self.rng.random(nUnassociatedDiaSources)**2)*100
        fluxErr = np.sqrt(flux)
        glint_trail = np.zeros(nUnassociatedDiaSources, dtype=bool)
        glint_trail[12:16] = True  # add 4 glint trail sources
        diaSources["reliability"] = reliability
        diaSources[config.newObjectFluxField] = flux
        diaSources[config.newObjectFluxField + "Err"] = fluxErr
        diaSources["glint_trail"] = glint_trail
        badFlagName = task.config.newObjectBadFlags[0]
        # Named differently from the ``badFlags`` argument of the helper below.
        badFlagValues = np.zeros(nUnassociatedDiaSources, dtype=bool)
        nBadFlags = 20
        badFlagValues[0:nBadFlags] = True
        diaSources[badFlagName] = badFlagValues

        def runAndCheckFilter(diaSources, snrThreshold=None, lowReliabilitySnrThreshold=None,
                              reliabilityThreshold=None, lowSnrReliabilityThreshold=None,
                              badFlags=None,
                              ):

            filterResults = task.filterSources(
                diaSources.copy(deep=True),
                snrThreshold=snrThreshold,
                lowReliabilitySnrThreshold=lowReliabilitySnrThreshold,
                reliabilityThreshold=reliabilityThreshold,
                lowSnrReliabilityThreshold=lowSnrReliabilityThreshold,
                badFlags=badFlags,
            )
            # Every input source must land in exactly one of the two outputs.
            self.assertEqual(len(filterResults.goodSources) + len(filterResults.badSources),
                             nUnassociatedDiaSources)
            goodFlux = filterResults.goodSources[config.newObjectFluxField]
            goodFluxErr = filterResults.goodSources[config.newObjectFluxField + "Err"]
            goodSnr = np.array(goodFlux/goodFluxErr)
            self.assertTrue(np.all(goodSnr > snrThreshold))
            goodReliability = np.array(filterResults.goodSources["reliability"])
            self.assertTrue(np.all(goodReliability > reliabilityThreshold))
            goodLowSnrFlag = goodSnr < lowReliabilitySnrThreshold
            lowSnrReliability = goodReliability[goodLowSnrFlag]
            self.assertTrue(np.all(lowSnrReliability > lowSnrReliabilityThreshold))
            glintTrailSources = np.array(filterResults.goodSources["glint_trail"])
            self.assertFalse(glintTrailSources.any())

        # No sources should be removed if the thresholds are turned off
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0)
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0,
                          badFlags=[badFlagName])
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5,
                          badFlags=[badFlagName])

    def testRunWithForcedMeasurement(self):
        """Test running association with forced photometry."""

        reliabilityThreshold = 0.5
        trailLengthThreshold = 1.0
        config = self._makeDefaultConfig(config_file=self.config_file.name,
                                         doPackageAlerts=False,
                                         forcedReliabilityThreshold=reliabilityThreshold,
                                         forcedTrailLengthThreshold=trailLengthThreshold)
        task = DiaPipelineTask(config=config)
        nDeepObjects = 20
        nShallowObjects = 20
        nGoodDiaSourcesDeep = 100
        nGoodDiaSourcesShallow = nShallowObjects
        nGoodDiaSources = nGoodDiaSourcesDeep + nGoodDiaSourcesShallow
        nBadDiaSources = 200

        # Create diaObjects
        diaObjectsDeep = makeDiaObjects(nDeepObjects, self.exposure, self.rng, startId=1)
        diaObjectsShallow = makeDiaObjects(nShallowObjects, self.exposure, self.rng, startId=1 + nDeepObjects)
        diaObjects = task.mergeCatalogs(diaObjectsDeep, diaObjectsShallow, tableName="DiaObject")

        diaSourcesGoodDeep = makeDiaSources(nGoodDiaSourcesDeep, diaObjectsDeep["diaObjectId"].to_numpy(),
                                            self.exposure, self.rng, startId=1,
                                            flagList=config.forcedBadFlags)
        diaSourcesGoodShallow = makeDiaSources(nGoodDiaSourcesShallow,
                                               diaObjectsShallow["diaObjectId"].to_numpy(),
                                               self.exposure, self.rng, startId=1 + nGoodDiaSourcesDeep,
                                               flagList=config.forcedBadFlags)

        diaSourcesGoodDeep = convertDataFrameToSdmSchema(task.schema, diaSourcesGoodDeep,
                                                         tableName="DiaSource", skipIndex=True)
        diaSourcesGood = task.mergeCatalogs(diaSourcesGoodDeep, diaSourcesGoodShallow, tableName="DiaSource")

        diaSourcesBad = makeDiaSources(nBadDiaSources, diaObjects["diaObjectId"].to_numpy(), self.exposure,
                                       self.rng, randomizeObjects=True, startId=1 + nGoodDiaSources,
                                       flagList=config.forcedBadFlags)
        diaSourcesGood['reliability'] = (self.rng.random(nGoodDiaSources)*(1 - reliabilityThreshold)
                                         + reliabilityThreshold)
        diaSourcesGood['trailLength'] = self.rng.random(nGoodDiaSources)*trailLengthThreshold
        diaSourcesBad = convertDataFrameToSdmSchema(task.schema, diaSourcesBad,
                                                    tableName="DiaSource", skipIndex=True)

        # Set some "bad" diaSources to have bad reliability, some good
        diaSourcesBad['reliability'] = self.rng.random(nBadDiaSources)
        # Set some "bad" diaSources to have too long trail lengths, some acceptable
        diaSourcesBad['trailLength'] = self.rng.random(nBadDiaSources)*trailLengthThreshold + 0.5
        missingBadFlags = ((diaSourcesBad['reliability'] > reliabilityThreshold)
                           & (diaSourcesBad['trailLength'] < trailLengthThreshold)
                           )
        for badFlag in config.forcedBadFlags:
            # Set a fraction of the bad diaSources to have each flag, assigned randomly
            diaSourcesBad[badFlag] = self.rng.random(nBadDiaSources) > 0.8
            missingBadFlags &= ~diaSourcesBad[badFlag]
        # Catch any "bad" diaSources that have not yet been flagged.
        # ``badFlag`` deliberately retains the last flag from the loop above.
        if np.any(missingBadFlags):
            diaSourcesBad[badFlag] |= missingBadFlags
        diaObjectsForcedGoodOnly = task._selectGoodDiaObjects(diaObjects, diaSourcesGood)
        diaObjectsForcedBadOnly = task._selectGoodDiaObjects(diaObjects, diaSourcesBad)
        # None of the diaObjects should be selected if only given the bad diaSources
        self.assertTrue(diaObjectsForcedBadOnly.empty)

        diaSources = task.mergeCatalogs(diaSourcesGood, diaSourcesBad, tableName="DiaSource")
        diaObjectsForced = task._selectGoodDiaObjects(diaObjects, diaSources)
        # The number of diaObjects selected should be the same regardless of
        # whether the bad diaSources are included
        self.assertEqual(len(diaObjectsForced), len(diaObjectsForcedGoodOnly))
        # All of the deep diaObjects should be selected, and none of the shallow
        self.assertEqual(len(diaObjectsForced), nDeepObjects)
        fSrc = task.runForcedMeasurement(diaObjectsForced, diaObjectsForced, self.exposure, self.exposure,
                                         IdGenerator())
        self.assertEqual(set(fSrc['diaObjectId']), set(diaObjectsDeep['diaObjectId']))

    def test_selectGoodObjects(self):
        """Test the diaObject selection function used for forced photometry.
        """
        reliabilityThreshold = 0.5
        trailLengthThreshold = 1.0
        config = self._makeDefaultConfig(config_file=self.config_file.name,
                                         doPackageAlerts=False,
                                         forcedReliabilityThreshold=reliabilityThreshold,
                                         forcedTrailLengthThreshold=trailLengthThreshold)
        task = DiaPipelineTask(config=config)
        nObjects = 20
        nDiaSources = 100

        # Create diaObjects
        diaObjects = makeDiaObjects(nObjects, self.exposure, self.rng, startId=1)

        diaSources = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy(),
                                    self.exposure, self.rng, startId=1,
                                    flagList=config.forcedBadFlags)
        # Since flagList is not specified, the columns will be missing, and added and filled with NaNs
        # when put through convertDataFrameToSdmSchema.
        diaSourcesMissingColumns = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy(),
                                                  self.exposure, self.rng, startId=1 + nDiaSources)
        # Should raise an error if run before adding the required columns
        with self.assertRaises(RuntimeError):
            task._selectGoodDiaObjects(diaObjects, diaSources)

        diaSources['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold)
                                     + reliabilityThreshold)
        diaSources['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold

        diaSourcesMissingColumns['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold)
                                                   + reliabilityThreshold)
        diaSourcesMissingColumns['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold

        diaSourcesNotMatched = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy() + 999,
                                              self.exposure, self.rng, startId=1)

        diaSourcesNotMatched['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold)
                                               + reliabilityThreshold)
        diaSourcesNotMatched['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold

        diaSources = convertDataFrameToSdmSchema(task.schema, diaSources,
                                                 tableName="DiaSource", skipIndex=True)
        diaSourcesMissingColumns = convertDataFrameToSdmSchema(task.schema, diaSourcesMissingColumns,
                                                               tableName="DiaSource", skipIndex=True)
        diaSourcesNotMatched = convertDataFrameToSdmSchema(task.schema, diaSourcesNotMatched,
                                                           tableName="DiaSource", skipIndex=True)

        # Run the method with missing flags
        diaObjectsSelected = task._selectGoodDiaObjects(diaObjects, diaSources)
        diaObjectsBadSelection = task._selectGoodDiaObjects(diaObjects, diaSourcesNotMatched)
        diaObjectsNaNSelection = task._selectGoodDiaObjects(diaObjects, diaSourcesMissingColumns)
        # The matched catalog of good diaSources should not drop any diaObjects
        self.assertTrue(diaObjects.equals(diaObjectsSelected))
        # The mis-matched catalog of good diaSources should drop every diaObject
        self.assertTrue(diaObjectsBadSelection.empty)
        # All diaObjects should be dropped for the diaSource catalog that has all NaN values for the flags
        self.assertTrue(diaObjectsNaNSelection.empty)

    def test_selectGoodObjectsWithWrongFlags(self):
        """Test the diaObject selection function when the configured bad
        flags are not columns of the diaSource catalog.
        """
        reliabilityThreshold = 0.5
        trailLengthThreshold = 1.0
        config = self._makeDefaultConfig(config_file=self.config_file.name,
                                         doPackageAlerts=False,
                                         forcedReliabilityThreshold=reliabilityThreshold,
                                         forcedTrailLengthThreshold=trailLengthThreshold,
                                         forcedBadFlags=['foo', 'bar'])
        task = DiaPipelineTask(config=config)
        nObjects = 20
        nDiaSources = 100

        # Create diaObjects
        diaObjects = makeDiaObjects(nObjects, self.exposure, self.rng, startId=1)

        diaSources = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy(),
                                    self.exposure, self.rng, startId=1)

        diaSources['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold)
                                     + reliabilityThreshold)
        diaSources['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold

        diaSourcesNotMatched = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy() + 999,
                                              self.exposure, self.rng, startId=1)

        diaSourcesNotMatched['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold)
                                               + reliabilityThreshold)
        diaSourcesNotMatched['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold

        diaSources = convertDataFrameToSdmSchema(task.schema, diaSources,
                                                 tableName="DiaSource", skipIndex=True)
        diaSourcesNotMatched = convertDataFrameToSdmSchema(task.schema, diaSourcesNotMatched,
                                                           tableName="DiaSource", skipIndex=True)

        # Verify that the flags are in fact missing
        for flag in config.forcedBadFlags:
            self.assertNotIn(flag, diaSources.columns)
        # Run the method with missing flags
        diaObjectsSelected = task._selectGoodDiaObjects(diaObjects, diaSources)
        diaObjectsBadSelection = task._selectGoodDiaObjects(diaObjects, diaSourcesNotMatched)
        # The matched catalog of good diaSources should not drop any diaObjects
        self.assertTrue(diaObjects.equals(diaObjectsSelected))
        # The mis-matched catalog of good diaSources should drop every diaObject
        self.assertTrue(diaObjectsBadSelection.empty)

    def test_mergeEmptyCatalog(self):
        """Test that a catalog is unchanged if it is merged with an empty
        catalog.
        """
        diaSourcesBase = self.diaSources

        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        # Include some but not all columns that should be in diaSourcesBase, and some that are mis-matched
        diaSourcesEmpty = pd.DataFrame(columns=["ra", "dec", "foo"])
        diaSourcesTest = task.mergeCatalogs(diaSourcesBase, diaSourcesEmpty, tableName="DiaSource")
        self.assertTrue(diaSourcesBase.equals(diaSourcesTest))

    def test_mergeCatalogs(self):
        """Test that a merged catalog is concatenated correctly.
        """
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)

        diaSourcesBase = convertDataFrameToSdmSchema(task.schema, self.diaSources, "DiaSource",
                                                     skipIndex=True)
        nBase = len(diaSourcesBase)
        nNew = int(nBase/2)

        diaSourcesNew = makeDiaSources(nNew, self.diaObjects["diaObjectId"].to_numpy(), self.exposure,
                                       self.rng)
        diaSourcesNew = convertDataFrameToSdmSchema(task.schema, diaSourcesNew, "DiaSource", skipIndex=True)
        diaSourcesTest = task.mergeCatalogs(diaSourcesBase, diaSourcesNew, tableName="DiaSource")
        self.assertEqual(len(diaSourcesTest), nBase + nNew)
        # The merged catalog should be the base rows followed by the new rows.
        diaSourcesExtract1 = diaSourcesTest.iloc[:nBase]
        diaSourcesExtract2 = diaSourcesTest.iloc[nBase:]

        pd.testing.assert_frame_equal(diaSourcesBase, diaSourcesExtract1)
        pd.testing.assert_frame_equal(diaSourcesNew, diaSourcesExtract2)

    def test_updateObjectTable(self):
        """Test that the diaObject record is updated with the number of
        diaSources.
        """
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        nObjects = 20
        nSrcPerObject = 10
        nExtraSources = 5
        nSources = nSrcPerObject*nObjects + nExtraSources
        # The first nExtraSources objects are expected to have one more source.
        expectedSourcesPerObject = nSrcPerObject*np.ones(nObjects)
        expectedSourcesPerObject[:nExtraSources] += 1
        diaObjects = makeDiaObjects(nObjects, self.exposure, self.rng)
        diaSources = makeDiaSources(nSources, diaObjects["diaObjectId"].to_numpy(), self.exposure, self.rng)
        updatedDiaObjects = task.updateObjectTable(diaObjects, diaSources)
        self.assertTrue(np.all(updatedDiaObjects.nDiaSources.values == expectedSourcesPerObject))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`
    without modification.
    """
def setup_module(module):
    """Module-level setup hook; calls `lsst.utils.tests.init`.

    The name follows pytest's ``setup_module`` convention.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; unused.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Script entry point: same initialization as setup_module, then run the
    # tests through unittest.
    lsst.utils.tests.init()
    unittest.main()