Coverage for tests / test_diaPipe.py: 13%

350 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 08:31 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import contextlib 

23import tempfile 

24import unittest 

25from unittest.mock import patch, MagicMock, DEFAULT 

26import warnings 

27 

28import numpy as np 

29import pandas as pd 

30import astropy.table as tb 

31 

32import lsst.afw.table as afwTable 

33import lsst.dax.apdb as daxApdb 

34from lsst.meas.base import IdGenerator 

35import lsst.pex.config as pexConfig 

36import lsst.pipe.base as pipeBase 

37import lsst.utils.tests 

38from lsst.pipe.base.testUtils import assertValidOutput 

39 

40from lsst.ap.association import DiaPipelineTask 

41from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema 

42from utils_tests import makeExposure, makeDiaObjects, makeDiaSources, makeDiaForcedSources, \ 

43 makeSolarSystemSources 

44 

45 

46def _makeMockDataFrame(): 

47 """Create a new mock of a DataFrame. 

48 

49 Returns 

50 ------- 

51 mock : `unittest.mock.Mock` 

52 A mock guaranteed to accept all operations used by `pandas.DataFrame`. 

53 """ 

54 with warnings.catch_warnings(): 

55 # spec triggers deprecation warnings on DataFrame, but will 

56 # automatically adapt to any removals. 

57 warnings.simplefilter("ignore", category=DeprecationWarning) 

58 return MagicMock(spec=pd.DataFrame()) 

59 

60 

def _makeMockTable():
    """Build a fresh stand-in for an `astropy.table.Table`.

    Returns
    -------
    mock : `unittest.mock.MagicMock`
        A mock whose spec is an empty `astropy.table.Table` instance, so it
        only exposes attributes that a real Table has.
    """
    with warnings.catch_warnings():
        # Building the spec inspects every Table attribute; ignore any
        # deprecation warnings that triggers. Using the live class as the
        # spec means removed attributes drop out on their own.
        warnings.simplefilter("ignore", category=DeprecationWarning)
        specInstance = tb.Table()
        mockTable = MagicMock(spec=specInstance)
    return mockTable

74 

75 

76class TestDiaPipelineTask(unittest.TestCase): 

77 

78 @classmethod 

79 def _makeDefaultConfig(cls, config_file, **kwargs): 

80 config = DiaPipelineTask.ConfigClass() 

81 config.apdb_config_url = config_file 

82 config.update(**kwargs) 

83 return config 

84 

85 def setUp(self): 

86 # Create an instance of random generator with fixed seed. 

87 rng = np.random.default_rng(1234) 

88 self.rng = rng 

89 

90 # schemas are persisted in both Gen 2 and Gen 3 butler as prototypical catalogs 

91 srcSchema = afwTable.SourceTable.makeMinimalSchema() 

92 srcSchema.addField("base_PixelFlags_flag", type="Flag") 

93 srcSchema.addField("base_PixelFlags_flag_offimage", type="Flag") 

94 self.srcSchema = afwTable.SourceCatalog(srcSchema) 

95 self.exposure = makeExposure(False, False) 

96 self.diffim = makeExposure(False, False) 

97 self.template = makeExposure(False, False) 

98 self.diaObjects = makeDiaObjects(20, self.exposure, rng) 

99 self.diaSources = makeDiaSources( 

100 100, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng) 

101 self.diaForcedSources = makeDiaForcedSources( 

102 200, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng) 

103 self.ssSources = makeSolarSystemSources( 

104 20, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng) 

105 

106 sqlite_file = tempfile.NamedTemporaryFile() 

107 self.addCleanup(sqlite_file.close) 

108 self.config_file = tempfile.NamedTemporaryFile() 

109 self.addCleanup(self.config_file.close) 

110 apdb_config = daxApdb.ApdbSql.init_database(db_url=f"sqlite:///{sqlite_file.name}") 

111 apdb_config.save(self.config_file.name) 

112 

113 def testRun(self): 

114 """Test running while creating and packaging alerts. 

115 """ 

116 self._testRun(doPackageAlerts=True, doSolarSystemAssociation=True, doReloadDiaObjects=False) 

117 

118 def testRunWithSolarSystemAssociation(self): 

119 """Test running while creating and packaging alerts. 

120 """ 

121 self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True, doReloadDiaObjects=False) 

122 

123 def testRunWithAlerts(self): 

124 """Test running while creating and packaging alerts. 

125 """ 

126 self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=False) 

127 

128 def testRunWithoutAlertsOrSolarSystem(self): 

129 """Test running without creating and packaging alerts. 

130 """ 

131 self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False, doReloadDiaObjects=False) 

132 

133 def testRunWithReload(self): 

134 """Test running with reloading DiaObjects. 

135 """ 

136 self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False, doReloadDiaObjects=True) 

137 

138 def testRunWithReloadAndSolarSystem(self): 

139 """Test running with solar system association and reloading DiaObjects. 

140 """ 

141 self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True, doReloadDiaObjects=True) 

142 

143 def testRunWithReloadAndAlerts(self): 

144 """Test running with reloading DiaObjects while creating and packaging alerts. 

145 """ 

146 self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=True) 

147 

148 def testRunDisableDeprecatedDoRunForcedMeasurement(self): 

149 """Test running with forced sources disabled. 

150 """ 

151 self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=True, 

152 doRunForcedMeasurement=False, subtasksToMock=["diaCalculation", ] 

153 ) 

154 

155 def _testRun(self, doPackageAlerts=False, doSolarSystemAssociation=False, 

156 doReloadDiaObjects=False, subtasksToMock=None, **kwargs): 

157 """Test the normal workflow of each ap_pipe step. 

158 """ 

159 config = self._makeDefaultConfig( 

160 config_file=self.config_file.name, 

161 doPackageAlerts=doPackageAlerts, 

162 doSolarSystemAssociation=doSolarSystemAssociation, 

163 doReloadDiaObjects=doReloadDiaObjects, 

164 **kwargs 

165 ) 

166 task = DiaPipelineTask(config=config) 

167 # Set DataFrame index testing to always return False. Mocks return 

168 # true for this check otherwise. 

169 task.testDataFrameIndex = lambda x: False 

170 diaSrc = _makeMockDataFrame() 

171 ssObjects = _makeMockTable() 

172 

173 # Each of these subtasks should be called once during diaPipe 

174 # execution. We use mocks here to check they are being executed 

175 # appropriately. 

176 if subtasksToMock is None: 

177 subtasksToMock = ["diaCalculation", "diaForcedSource", ] 

178 if doPackageAlerts: 

179 subtasksToMock.append("alertPackager") 

180 else: 

181 self.assertFalse(hasattr(task, "alertPackager")) 

182 

183 if not doSolarSystemAssociation: 

184 self.assertFalse(hasattr(task, "solarSystemAssociator")) 

185 

186 def concatMock(_data, **_kwargs): 

187 return _makeMockDataFrame() 

188 

189 # Mock out the run() methods of these Tasks to ensure they 

190 # return data in the correct form. 

191 def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, visitInfo, 

192 bbox, wcs): 

193 return lsst.pipe.base.Struct(nTotalSsObjects=42, 

194 nAssociatedSsObjects=30, 

195 ssoAssocDiaSources=_makeMockTable(), 

196 unAssocDiaSources=_makeMockTable(), 

197 associatedSsSources=_makeMockTable(), 

198 unassociatedSsObjects=_makeMockTable()) 

199 

200 def associator_run(table, diaObjects): 

201 return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3, 

202 matchedDiaSources=_makeMockDataFrame(), 

203 unAssocDiaSources=_makeMockDataFrame()) 

204 

205 def loadObjects_run(region, preloadedDiaObjects): 

206 task.metadata['loadRefreshedDiaObjectsStartUtc'] = 1.234 

207 task.metadata['loadRefreshedDiaObjectsEndUtc'] = 5.678 

208 return self.diaObjects 

209 

210 def updateObjectTableMock(diaObjects, diaSources): 

211 pass 

212 

213 def _selectGoodDiaObjects(diaObjectCat, mergedDiaSourceHistory): 

214 return diaObjectCat.copy(deep=True) 

215 

216 # apdb isn't a subtask, but still needs to be mocked out for correct 

217 # execution in the test environment. 

218 with patch.multiple(task, **{task: DEFAULT for task in subtasksToMock + ["apdb"]}), \ 

219 patch('lsst.ap.association.diaPipe.pd.concat', side_effect=concatMock), \ 

220 patch('lsst.ap.association.diaPipe.DiaPipelineTask.updateObjectTable', 

221 side_effect=updateObjectTableMock), \ 

222 patch('lsst.ap.association.diaPipe.DiaPipelineTask._selectGoodDiaObjects', 

223 side_effect=_selectGoodDiaObjects), \ 

224 patch('lsst.ap.association.diaPipe.DiaPipelineTask.loadRefreshedDiaObjects', 

225 side_effect=loadObjects_run), \ 

226 patch('lsst.ap.association.association.AssociationTask.run', 

227 side_effect=associator_run) as mainRun, \ 

228 patch('lsst.pipe.tasks.ssoAssociation.SolarSystemAssociationTask.run', 

229 side_effect=solarSystemAssociator_run) as ssRun: 

230 

231 result = task.run(diaSrc, 

232 None, 

233 self.diffim, 

234 self.exposure, 

235 self.template, 

236 preloadedDiaObjects=self.diaObjects, 

237 preloadedDiaSources=self.diaSources, 

238 preloadedDiaForcedSources=self.diaForcedSources, 

239 band="g", 

240 idGenerator=IdGenerator(), 

241 solarSystemObjectTable=ssObjects) 

242 for subtaskName in subtasksToMock: 

243 getattr(task, subtaskName).run.assert_called_once() 

244 assertValidOutput(task, result) 

245 # Exact type and contents of apdbMarker are undefined. 

246 self.assertIsInstance(result.apdbMarker, pexConfig.Config) 

247 meta = task.getFullMetadata() 

248 # Check that the expected metadata has been set. 

249 self.assertEqual(meta["diaPipe.numUpdatedDiaObjects"], 2) 

250 self.assertEqual(meta["diaPipe.numUnassociatedDiaObjects"], 3) 

251 # and that associators ran once or not at all. 

252 mainRun.assert_called_once() 

253 if doSolarSystemAssociation: 

254 ssRun.assert_called_once() 

255 else: 

256 ssRun.assert_not_called() 

257 

258 def test_tooManyDiaObjectsError(self): 

259 maxNewDiaObjects = 100 

260 

261 nDiaSources = maxNewDiaObjects + 1 

262 diaSources = makeDiaSources(nDiaSources, np.zeros(nDiaSources), self.exposure, self.rng) 

263 

264 def runAndTestWithContextManager(threshold): 

265 config = self._makeDefaultConfig(config_file=self.config_file.name, 

266 doSolarSystemAssociation=False, 

267 filterUnAssociatedSources=False, 

268 maxNewDiaObjects=threshold, 

269 ) 

270 task = DiaPipelineTask(config=config) 

271 contextManager = self.assertRaises(pipeBase.AlgorithmError) if nDiaSources > threshold > 0 \ 

272 else contextlib.nullcontext() 

273 with contextManager: 

274 task.associateDiaSources( 

275 diaSources, 

276 None, 

277 None, 

278 self.diaObjects, 

279 ) 

280 # Test cases at, above, and below the threshold as well as at 0. 

281 runAndTestWithContextManager(0) 

282 runAndTestWithContextManager(maxNewDiaObjects - 1) 

283 runAndTestWithContextManager(maxNewDiaObjects) 

284 runAndTestWithContextManager(maxNewDiaObjects + 1) 

285 

286 def test_createDiaObjects(self): 

287 """Test that creating new DiaObjects works as expected. 

288 """ 

289 nSources = 5 

290 config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False) 

291 task = DiaPipelineTask(config=config) 

292 diaSources = pd.DataFrame(data=[ 

293 {"ra": 0.04*idx, "dec": 0.04*idx, 

294 "diaSourceId": idx + 1 + nSources, "diaObjectId": 0, 

295 "ssObjectId": 0} 

296 for idx in range(nSources)]) 

297 

298 result = task.createNewDiaObjects(convertDataFrameToSdmSchema(task.schema, diaSources, "DiaSource", 

299 skipIndex=True)) 

300 self.assertEqual(nSources, len(result.newDiaObjects)) 

301 self.assertTrue(np.all(np.equal( 

302 result.diaSources["diaObjectId"].to_numpy(), 

303 result.diaSources["diaSourceId"].to_numpy()))) 

304 self.assertTrue(np.all(np.equal( 

305 result.newDiaObjects["diaObjectId"].to_numpy(), 

306 result.diaSources["diaSourceId"].to_numpy()))) 

307 

308 def test_purgeDiaObjects(self): 

309 """Remove diaOjects that are outside an image's bounding box. 

310 """ 

311 

312 config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False) 

313 task = DiaPipelineTask(config=config) 

314 exposure = makeExposure(False, False) 

315 nObj0 = 20 

316 

317 # Create diaObjects 

318 diaObjects = makeDiaObjects(nObj0, exposure, self.rng) 

319 # Shrink the bounding box so that some of the diaObjects will be outside 

320 bbox = exposure.getBBox() 

321 size = np.minimum(bbox.getHeight(), bbox.getWidth()) 

322 bbox.grow(-size//4) 

323 exposureCut = exposure[bbox] 

324 sizeCut = np.minimum(bbox.getHeight(), bbox.getWidth()) 

325 buffer = 10 

326 bbox.grow(buffer) 

327 

328 def check_diaObjects(bbox, wcs, diaObjects): 

329 raVals = diaObjects.ra.to_numpy() 

330 decVals = diaObjects.dec.to_numpy() 

331 xVals, yVals = wcs.skyToPixelArray(raVals, decVals, degrees=True) 

332 selector = bbox.contains(xVals, yVals) 

333 return selector 

334 

335 selector0 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects) 

336 nIn0 = np.count_nonzero(selector0) 

337 nOut0 = np.count_nonzero(~selector0) 

338 self.assertEqual(nObj0, nIn0 + nOut0) 

339 

340 # Add an ID that is not in the diaObject table. It should not get removed. 

341 diaObjectIds0 = diaObjects["diaObjectId"].copy(deep=True) 

342 diaObjectIds0[max(diaObjectIds0.index) + 1] = 999 

343 diaObjects1, objIds = task.purgeDiaObjects(exposureCut.getBBox(), exposureCut.getWcs(), diaObjects, 

344 diaObjectIds=diaObjectIds0, buffer=buffer) 

345 diaObjectIds1 = diaObjects1["diaObjectId"] 

346 # Verify that the bounding box was not changed 

347 sizeCheck = np.minimum(exposureCut.getBBox().getHeight(), exposureCut.getBBox().getWidth()) 

348 self.assertEqual(sizeCut, sizeCheck) 

349 selector1 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects1) 

350 nIn1 = np.count_nonzero(selector1) 

351 nOut1 = np.count_nonzero(~selector1) 

352 nObj1 = len(diaObjects1) 

353 self.assertEqual(nObj1, nIn0) 

354 # Verify that not all diaObjects were removed 

355 self.assertGreater(nObj1, 0) 

356 # Check that some diaObjects were removed 

357 self.assertLess(nObj1, nObj0) 

358 # Verify that no objects outside the bounding box remain 

359 self.assertEqual(nOut1, 0) 

360 # Verify that no objects inside the bounding box were removed 

361 self.assertEqual(nIn1, nIn0) 

362 # The length of the updated object IDs should equal the number of objects 

363 # plus one, since we added an extra ID. 

364 self.assertEqual(nObj1 + 1, len(objIds)) 

365 # All of the object IDs extracted from the catalog should be in the pruned object IDs 

366 self.assertTrue(set(objIds).issuperset(diaObjectIds1)) 

367 # The pruned object IDs should contain entries that are not in the catalog 

368 self.assertFalse(set(diaObjectIds1).issuperset(objIds)) 

369 # Some IDs should have been removed 

370 self.assertLess(len(objIds), len(diaObjectIds0)) 

371 

372 def test_filterDiaObjects(self): 

373 """Unassociated diaSources that are filtered should have good reliability and SNR. 

374 Glint trail sources should also be filtered out. 

375 """ 

376 

377 config = self._makeDefaultConfig(config_file=self.config_file.name, 

378 doPackageAlerts=False, 

379 filterUnAssociatedSources=True) 

380 

381 configBadFilter = self._makeDefaultConfig(config_file=self.config_file.name, 

382 doPackageAlerts=False, 

383 filterUnAssociatedSources=True, 

384 newObjectFluxField="notAFlux") 

385 configBadFlag = self._makeDefaultConfig(config_file=self.config_file.name, 

386 doPackageAlerts=False, 

387 filterUnAssociatedSources=True, 

388 newObjectBadFlags=("junkSource", "notUsed")) 

389 with self.assertRaises(pipeBase.InvalidQuantumError): 

390 DiaPipelineTask(config=configBadFilter) 

391 with self.assertRaises(pipeBase.InvalidQuantumError): 

392 DiaPipelineTask(config=configBadFlag) 

393 task = DiaPipelineTask(config=config) 

394 nUnassociatedDiaSources = 234 

395 

396 # Create diaSources 

397 diaSources = makeDiaSources(nUnassociatedDiaSources, 

398 np.zeros(nUnassociatedDiaSources), 

399 self.exposure, 

400 self.rng, 

401 flagList=task.config.newObjectBadFlags) 

402 reliability = self.rng.random(nUnassociatedDiaSources) 

403 flux = (self.rng.random(nUnassociatedDiaSources)**2)*100 

404 fluxErr = np.sqrt(flux) 

405 glint_trail = np.zeros(nUnassociatedDiaSources, dtype=bool) 

406 glint_trail[12:16] = True # add 4 glint trail sources 

407 diaSources["reliability"] = reliability 

408 diaSources[config.newObjectFluxField] = flux 

409 diaSources[config.newObjectFluxField + "Err"] = fluxErr 

410 diaSources["glint_trail"] = glint_trail 

411 badFlagName = task.config.newObjectBadFlags[0] 

412 badFlags = np.zeros(nUnassociatedDiaSources, dtype=bool) 

413 nBadFlags = 20 

414 badFlags[0:nBadFlags] = True 

415 diaSources[badFlagName] = badFlags 

416 

417 def runAndCheckFilter(diaSources, snrThreshold=None, lowReliabilitySnrThreshold=None, 

418 reliabilityThreshold=None, lowSnrReliabilityThreshold=None, 

419 badFlags=None, 

420 ): 

421 

422 filterResults = task.filterSources( 

423 diaSources.copy(deep=True), 

424 snrThreshold=snrThreshold, 

425 lowReliabilitySnrThreshold=lowReliabilitySnrThreshold, 

426 reliabilityThreshold=reliabilityThreshold, 

427 lowSnrReliabilityThreshold=lowSnrReliabilityThreshold, 

428 badFlags=badFlags, 

429 ) 

430 self.assertEqual(len(filterResults.goodSources) + len(filterResults.badSources), 

431 nUnassociatedDiaSources) 

432 goodFlux = filterResults.goodSources[config.newObjectFluxField] 

433 goodFluxErr = filterResults.goodSources[config.newObjectFluxField + "Err"] 

434 goodSnr = np.array(goodFlux/goodFluxErr) 

435 self.assertTrue(np.all(goodSnr > snrThreshold)) 

436 goodReliability = np.array(filterResults.goodSources["reliability"]) 

437 self.assertTrue(np.all(goodReliability > reliabilityThreshold)) 

438 goodLowSnrFlag = goodSnr < lowReliabilitySnrThreshold 

439 lowSnrReliability = goodReliability[goodLowSnrFlag] 

440 self.assertTrue(np.all(lowSnrReliability > lowSnrReliabilityThreshold)) 

441 glintTrailSources = np.array(filterResults.goodSources["glint_trail"]) 

442 self.assertTrue(not any(glintTrailSources)) 

443 

444 # No sources should be removed if the thresholds are turned off 

445 runAndCheckFilter(diaSources, 

446 snrThreshold=0, lowReliabilitySnrThreshold=0, 

447 reliabilityThreshold=0, lowSnrReliabilityThreshold=0) 

448 runAndCheckFilter(diaSources, 

449 snrThreshold=0, lowReliabilitySnrThreshold=0, 

450 reliabilityThreshold=0, lowSnrReliabilityThreshold=0, 

451 badFlags=[badFlagName]) 

452 runAndCheckFilter(diaSources, 

453 snrThreshold=2, lowReliabilitySnrThreshold=8, 

454 reliabilityThreshold=0, lowSnrReliabilityThreshold=0) 

455 runAndCheckFilter(diaSources, 

456 snrThreshold=2, lowReliabilitySnrThreshold=8, 

457 reliabilityThreshold=0, lowSnrReliabilityThreshold=0.5) 

458 runAndCheckFilter(diaSources, 

459 snrThreshold=0, lowReliabilitySnrThreshold=0, 

460 reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5) 

461 runAndCheckFilter(diaSources, 

462 snrThreshold=2, lowReliabilitySnrThreshold=8, 

463 reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5) 

464 runAndCheckFilter(diaSources, 

465 snrThreshold=2, lowReliabilitySnrThreshold=8, 

466 reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5, 

467 badFlags=[badFlagName]) 

468 

469 def testRunWithForcedMeasurement(self): 

470 """Test running association with forced photometry.""" 

471 

472 reliabilityThreshold = 0.5 

473 trailLengthThreshold = 1.0 

474 config = self._makeDefaultConfig(config_file=self.config_file.name, 

475 doPackageAlerts=False, 

476 forcedReliabilityThreshold=reliabilityThreshold, 

477 forcedTrailLengthThreshold=trailLengthThreshold) 

478 task = DiaPipelineTask(config=config) 

479 nDeepObjects = 20 

480 nShallowObjects = 20 

481 nGoodDiaSourcesDeep = 100 

482 nGoodDiaSourcesShallow = nShallowObjects 

483 nGoodDiaSources = nGoodDiaSourcesDeep + nGoodDiaSourcesShallow 

484 nBadDiaSources = 200 

485 

486 # Create diaObjects 

487 diaObjectsDeep = makeDiaObjects(nDeepObjects, self.exposure, self.rng, startId=1) 

488 diaObjectsShallow = makeDiaObjects(nShallowObjects, self.exposure, self.rng, startId=1 + nDeepObjects) 

489 diaObjects = task.mergeCatalogs(diaObjectsDeep, diaObjectsShallow, tableName="DiaObject") 

490 

491 diaSourcesGoodDeep = makeDiaSources(nGoodDiaSourcesDeep, diaObjectsDeep["diaObjectId"].to_numpy(), 

492 self.exposure, self.rng, startId=1, 

493 flagList=config.forcedBadFlags) 

494 diaSourcesGoodShallow = makeDiaSources(nGoodDiaSourcesShallow, 

495 diaObjectsShallow["diaObjectId"].to_numpy(), 

496 self.exposure, self.rng, startId=1 + nGoodDiaSourcesDeep, 

497 flagList=config.forcedBadFlags) 

498 

499 diaSourcesGoodDeep = convertDataFrameToSdmSchema(task.schema, diaSourcesGoodDeep, 

500 tableName="DiaSource", skipIndex=True) 

501 diaSourcesGood = task.mergeCatalogs(diaSourcesGoodDeep, diaSourcesGoodShallow, tableName="DiaSource") 

502 

503 diaSourcesBad = makeDiaSources(nBadDiaSources, diaObjects["diaObjectId"].to_numpy(), self.exposure, 

504 self.rng, randomizeObjects=True, startId=1 + nGoodDiaSources, 

505 flagList=config.forcedBadFlags) 

506 diaSourcesGood['reliability'] = (self.rng.random(nGoodDiaSources)*(1 - reliabilityThreshold) 

507 + reliabilityThreshold) 

508 diaSourcesGood['trailLength'] = self.rng.random(nGoodDiaSources)*trailLengthThreshold 

509 diaSourcesBad = convertDataFrameToSdmSchema(task.schema, diaSourcesBad, 

510 tableName="DiaSource", skipIndex=True) 

511 

512 # Set some "bad" diaSources to have bad reliability, some good 

513 diaSourcesBad['reliability'] = self.rng.random(nBadDiaSources) 

514 # Set some "bad" diaSources to have too long trail lengths, some acceptible 

515 diaSourcesBad['trailLength'] = self.rng.random(nBadDiaSources)*trailLengthThreshold + 0.5 

516 missingBadFlags = ((diaSourcesBad['reliability'] > reliabilityThreshold) 

517 & (diaSourcesBad['trailLength'] < trailLengthThreshold) 

518 ) 

519 for badFlag in config.forcedBadFlags: 

520 # Set a fraction of the bad diaSources to have each flag, assigned randomly 

521 diaSourcesBad[badFlag] = self.rng.random(nBadDiaSources) > 0.8 

522 missingBadFlags &= ~diaSourcesBad[badFlag] 

523 # Catch any "bad" diaSources that have not yet been flagged 

524 if np.any(missingBadFlags): 

525 diaSourcesBad[badFlag] |= missingBadFlags 

526 diaObjectsForcedGoodOnly = task._selectGoodDiaObjects(diaObjects, diaSourcesGood) 

527 diaObjectsForcedBadOnly = task._selectGoodDiaObjects(diaObjects, diaSourcesBad) 

528 # None of the diaObjects should be selected if only given the bad diaSources 

529 self.assertTrue(diaObjectsForcedBadOnly.empty) 

530 

531 diaSources = task.mergeCatalogs(diaSourcesGood, diaSourcesBad, tableName="DiaSource") 

532 diaObjectsForced = task._selectGoodDiaObjects(diaObjects, diaSources) 

533 # The number of diaObjects selected should be the same regardless of 

534 # whether the bad diaSources are included 

535 self.assertEqual(len(diaObjectsForced), len(diaObjectsForcedGoodOnly)) 

536 # All of the deep diaObjects should be selected, and none of the shallow 

537 self.assertEqual(len(diaObjectsForced), nDeepObjects) 

538 fSrc = task.runForcedMeasurement(diaObjectsForced, diaObjectsForced, self.exposure, self.exposure, 

539 IdGenerator()) 

540 self.assertEqual(set(fSrc['diaObjectId']), set(diaObjectsDeep['diaObjectId'])) 

541 

542 def test_selectGoodObjects(self): 

543 """Test the diaObject selection funtion used for forced photometry. 

544 """ 

545 reliabilityThreshold = 0.5 

546 trailLengthThreshold = 1.0 

547 config = self._makeDefaultConfig(config_file=self.config_file.name, 

548 doPackageAlerts=False, 

549 forcedReliabilityThreshold=reliabilityThreshold, 

550 forcedTrailLengthThreshold=trailLengthThreshold) 

551 task = DiaPipelineTask(config=config) 

552 nObjects = 20 

553 nDiaSources = 100 

554 

555 # Create diaObjects 

556 diaObjects = makeDiaObjects(nObjects, self.exposure, self.rng, startId=1) 

557 

558 diaSources = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy(), 

559 self.exposure, self.rng, startId=1, 

560 flagList=config.forcedBadFlags) 

561 # Since flagList is not specified, the columns will be missing, and added and filled with NaNs 

562 # when put through convertDataFrameToSdmSchema. 

563 diaSourcesMissingColumns = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy(), 

564 self.exposure, self.rng, startId=1 + nDiaSources) 

565 # Should raise an error if run before adding the required columns 

566 with self.assertRaises(RuntimeError): 

567 task._selectGoodDiaObjects(diaObjects, diaSources) 

568 

569 diaSources['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold) 

570 + reliabilityThreshold) 

571 diaSources['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold 

572 

573 diaSourcesMissingColumns['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold) 

574 + reliabilityThreshold) 

575 diaSourcesMissingColumns['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold 

576 

577 diaSourcesNotMatched = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy() + 999, 

578 self.exposure, self.rng, startId=1) 

579 

580 diaSourcesNotMatched['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold) 

581 + reliabilityThreshold) 

582 diaSourcesNotMatched['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold 

583 

584 diaSources = convertDataFrameToSdmSchema(task.schema, diaSources, 

585 tableName="DiaSource", skipIndex=True) 

586 diaSourcesMissingColumns = convertDataFrameToSdmSchema(task.schema, diaSourcesMissingColumns, 

587 tableName="DiaSource", skipIndex=True) 

588 diaSourcesNotMatched = convertDataFrameToSdmSchema(task.schema, diaSourcesNotMatched, 

589 tableName="DiaSource", skipIndex=True) 

590 

591 # Run the method with missing flags 

592 diaObjectsSelected = task._selectGoodDiaObjects(diaObjects, diaSources) 

593 diaObjectsBadSelection = task._selectGoodDiaObjects(diaObjects, diaSourcesNotMatched) 

594 diaObjectsNaNSelection = task._selectGoodDiaObjects(diaObjects, diaSourcesMissingColumns) 

595 # The matched catalog of good diaSources should not drop any diaObjects 

596 self.assertTrue(diaObjects.equals(diaObjectsSelected)) 

597 # The mis-matched catalog of good diaSources should drop every diaObject 

598 self.assertTrue(diaObjectsBadSelection.empty) 

599 # All diaObjects should be dropped for the diaSource catalog that has all NaN values for the flags 

600 self.assertTrue(diaObjectsNaNSelection.empty) 

601 

602 def test_selectGoodObjectsWithWrongFlags(self): 

603 """Test the diaObject selection funtion used for forced photometry. 

604 """ 

605 reliabilityThreshold = 0.5 

606 trailLengthThreshold = 1.0 

607 config = self._makeDefaultConfig(config_file=self.config_file.name, 

608 doPackageAlerts=False, 

609 forcedReliabilityThreshold=reliabilityThreshold, 

610 forcedTrailLengthThreshold=trailLengthThreshold, 

611 forcedBadFlags=['foo', 'bar']) 

612 task = DiaPipelineTask(config=config) 

613 nObjects = 20 

614 nDiaSources = 100 

615 

616 # Create diaObjects 

617 diaObjects = makeDiaObjects(nObjects, self.exposure, self.rng, startId=1) 

618 

619 diaSources = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy(), 

620 self.exposure, self.rng, startId=1) 

621 

622 diaSources['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold) 

623 + reliabilityThreshold) 

624 diaSources['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold 

625 

626 diaSourcesNotMatched = makeDiaSources(nDiaSources, diaObjects["diaObjectId"].to_numpy() + 999, 

627 self.exposure, self.rng, startId=1) 

628 

629 diaSourcesNotMatched['reliability'] = (self.rng.random(nDiaSources)*(1 - reliabilityThreshold) 

630 + reliabilityThreshold) 

631 diaSourcesNotMatched['trailLength'] = self.rng.random(nDiaSources)*trailLengthThreshold 

632 

633 diaSources = convertDataFrameToSdmSchema(task.schema, diaSources, 

634 tableName="DiaSource", skipIndex=True) 

635 diaSourcesNotMatched = convertDataFrameToSdmSchema(task.schema, diaSourcesNotMatched, 

636 tableName="DiaSource", skipIndex=True) 

637 

638 # Verify that the flags are in fact missing 

639 for flag in config.forcedBadFlags: 

640 self.assertNotIn(flag, diaSources.columns) 

641 # Run the method with missing flags 

642 diaObjectsSelected = task._selectGoodDiaObjects(diaObjects, diaSources) 

643 diaObjectsBadSelection = task._selectGoodDiaObjects(diaObjects, diaSourcesNotMatched) 

644 # The matched catalog of good diaSources should not drop any diaObjects 

645 self.assertTrue(diaObjects.equals(diaObjectsSelected)) 

646 # The mis-matched catalog of good diaSources should drop every diaObject 

647 self.assertTrue(diaObjectsBadSelection.empty) 

648 

649 def test_mergeEmptyCatalog(self): 

650 """Test that a catalog is unchanged if it is merged with an empty 

651 catalog. 

652 """ 

653 diaSourcesBase = self.diaSources 

654 

655 config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False) 

656 task = DiaPipelineTask(config=config) 

657 # Include some but not all columns that should be in diaSourcesBase, and some that are mis-matched 

658 diaSourcesEmpty = pd.DataFrame(columns=["ra", "dec", "foo"]) 

659 diaSourcesTest = task.mergeCatalogs(diaSourcesBase, diaSourcesEmpty, tableName="DiaSource") 

660 self.assertTrue(diaSourcesBase.equals(diaSourcesTest)) 

661 

662 def test_mergeCatalogs(self): 

663 """Test that a merged catalog is concatenated correctly. 

664 """ 

665 config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False) 

666 task = DiaPipelineTask(config=config) 

667 

668 diaSourcesBase = convertDataFrameToSdmSchema(task.schema, self.diaSources, "DiaSource", 

669 skipIndex=True) 

670 nBase = len(diaSourcesBase) 

671 nNew = int(nBase/2) 

672 

673 diaSourcesNew = makeDiaSources(nNew, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, 

674 self.rng) 

675 diaSourcesNew = convertDataFrameToSdmSchema(task.schema, diaSourcesNew, "DiaSource", skipIndex=True) 

676 diaSourcesTest = task.mergeCatalogs(diaSourcesBase, diaSourcesNew, tableName="DiaSource") 

677 self.assertEqual(len(diaSourcesTest), nBase + nNew) 

678 diaSourcesExtract1 = diaSourcesTest.iloc[:nBase] 

679 diaSourcesExtract2 = diaSourcesTest.iloc[nBase:] 

680 

681 pd.testing.assert_frame_equal(diaSourcesBase, diaSourcesExtract1) 

682 pd.testing.assert_frame_equal(diaSourcesNew, diaSourcesExtract2) 

683 

684 def test_updateObjectTable(self): 

685 """Test that the diaObject record is updated with the number of 

686 diaSources. 

687 """ 

688 config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False) 

689 task = DiaPipelineTask(config=config) 

690 nObjects = 20 

691 nSrcPerObject = 10 

692 nExtraSources = 5 

693 nSources = nSrcPerObject*nObjects + nExtraSources 

694 expectedSourcesPerObject = nSrcPerObject*np.ones(nObjects) 

695 expectedSourcesPerObject[:nExtraSources] += 1 

696 diaObjects = makeDiaObjects(nObjects, self.exposure, self.rng) 

697 diaSources = makeDiaSources(nSources, diaObjects["diaObjectId"].to_numpy(), self.exposure, self.rng) 

698 updatedDiaObjects = task.updateObjectTable(diaObjects, diaSources) 

699 self.assertTrue(np.all(updatedDiaObjects.nDiaSources.values == expectedSourcesPerObject)) 

700 

701 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Test case inherited unchanged from `lsst.utils.tests.MemoryTestCase`.
    """

704 

705 

def setup_module(module):
    """Initialize the LSST test utilities for this module.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

708 

709 

if __name__ == "__main__":
    # When run as a script: initialize the LSST test utilities, then hand
    # control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()