Coverage for tests / test_diaPipe.py: 17%

257 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 18:39 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import contextlib 

23import tempfile 

24import unittest 

25from unittest.mock import patch, Mock, MagicMock, DEFAULT 

26import warnings 

27 

28import numpy as np 

29import pandas as pd 

30import astropy.table as tb 

31 

32import lsst.afw.image as afwImage 

33import lsst.afw.table as afwTable 

34import lsst.dax.apdb as daxApdb 

35from lsst.meas.base import IdGenerator 

36import lsst.pex.config as pexConfig 

37import lsst.pipe.base as pipeBase 

38import lsst.utils.tests 

39from lsst.pipe.base.testUtils import assertValidOutput 

40 

41from lsst.ap.association import DiaPipelineTask 

42from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema 

43from utils_tests import makeExposure, makeDiaObjects, makeDiaSources, makeDiaForcedSources, \ 

44 makeSolarSystemSources 

45 

46 

47def _makeMockDataFrame(): 

48 """Create a new mock of a DataFrame. 

49 

50 Returns 

51 ------- 

52 mock : `unittest.mock.Mock` 

53 A mock guaranteed to accept all operations used by `pandas.DataFrame`. 

54 """ 

55 with warnings.catch_warnings(): 

56 # spec triggers deprecation warnings on DataFrame, but will 

57 # automatically adapt to any removals. 

58 warnings.simplefilter("ignore", category=DeprecationWarning) 

59 return MagicMock(spec=pd.DataFrame()) 

60 

61 

def _makeMockTable():
    """Create a new mock of an astropy Table.

    Returns
    -------
    mock : `unittest.mock.MagicMock`
        A mock specced on a live `astropy.table.Table` instance, so it
        accepts the attributes of the installed astropy version and rejects
        anything else.
    """
    with warnings.catch_warnings():
        # spec triggers deprecation warnings on Table, but will
        # automatically adapt to any removals.
        warnings.simplefilter("ignore", category=DeprecationWarning)
        return MagicMock(spec=tb.Table())

75 

76 

class TestDiaPipelineTask(unittest.TestCase):
    """Tests for `lsst.ap.association.DiaPipelineTask`."""

    @classmethod
    def _makeDefaultConfig(cls, config_file, **kwargs):
        """Build a task config that points at the test APDB.

        Parameters
        ----------
        config_file : `str`
            Path to the saved APDB configuration file.
        **kwargs
            Extra config field values, applied with ``config.update``.

        Returns
        -------
        config : `DiaPipelineTask.ConfigClass`
            The populated task config.
        """
        config = DiaPipelineTask.ConfigClass()
        config.apdb_config_url = config_file
        config.update(**kwargs)
        return config

    def setUp(self):
        # Create an instance of random generator with fixed seed.
        rng = np.random.default_rng(1234)
        self.rng = rng

        # schemas are persisted in both Gen 2 and Gen 3 butler as prototypical catalogs
        srcSchema = afwTable.SourceTable.makeMinimalSchema()
        srcSchema.addField("base_PixelFlags_flag", type="Flag")
        srcSchema.addField("base_PixelFlags_flag_offimage", type="Flag")
        self.srcSchema = afwTable.SourceCatalog(srcSchema)
        self.exposure = makeExposure(False, False)
        self.diaObjects = makeDiaObjects(20, self.exposure, rng)
        self.diaSources = makeDiaSources(
            100, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.diaForcedSources = makeDiaForcedSources(
            200, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.ssSources = makeSolarSystemSources(
            20, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)

        # Throwaway SQLite APDB; the temporary files are closed (and thereby
        # deleted) by the registered cleanups.
        sqlite_file = tempfile.NamedTemporaryFile()
        self.addCleanup(sqlite_file.close)
        self.config_file = tempfile.NamedTemporaryFile()
        self.addCleanup(self.config_file.close)
        apdb_config = daxApdb.ApdbSql.init_database(db_url=f"sqlite:///{sqlite_file.name}")
        apdb_config.save(self.config_file.name)

    def testRun(self):
        """Test running with alert packaging and solar system association.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=True, doReloadDiaObjects=False)

    def testRunWithSolarSystemAssociation(self):
        """Test running with solar system association but without alerts.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True, doReloadDiaObjects=False)

    def testRunWithAlerts(self):
        """Test running while creating and packaging alerts, without solar
        system association.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=False)

    def testRunWithoutAlertsOrSolarSystem(self):
        """Test running without alerts or solar system association.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False, doReloadDiaObjects=False)

    def testRunWithReload(self):
        """Test running with reloading DiaObjects.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=False, doReloadDiaObjects=True)

    def testRunWithReloadAndSolarSystem(self):
        """Test running with solar system association and reloading DiaObjects.
        """
        self._testRun(doPackageAlerts=False, doSolarSystemAssociation=True, doReloadDiaObjects=True)

    def testRunWithReloadAndAlerts(self):
        """Test running with reloading DiaObjects while creating and packaging alerts.
        """
        self._testRun(doPackageAlerts=True, doSolarSystemAssociation=False, doReloadDiaObjects=True)

    def _testRun(self, doPackageAlerts=False, doSolarSystemAssociation=False, doRunForcedMeasurement=False,
                 doReloadDiaObjects=False):
        """Test the normal workflow of each ap_pipe step.

        Subtasks, the APDB, and several association/update methods are
        replaced with mocks so that only the orchestration performed by
        `DiaPipelineTask.run` is exercised.

        Parameters
        ----------
        doPackageAlerts : `bool`, optional
            Value for the ``doPackageAlerts`` config field.
        doSolarSystemAssociation : `bool`, optional
            Value for the ``doSolarSystemAssociation`` config field.
        doRunForcedMeasurement : `bool`, optional
            Value for the ``doRunForcedMeasurement`` config field.
        doReloadDiaObjects : `bool`, optional
            Value for the ``doReloadDiaObjects`` config field.
        """
        config = self._makeDefaultConfig(
            config_file=self.config_file.name,
            doPackageAlerts=doPackageAlerts,
            doSolarSystemAssociation=doSolarSystemAssociation,
            doRunForcedMeasurement=doRunForcedMeasurement,
            doReloadDiaObjects=doReloadDiaObjects,
        )
        task = DiaPipelineTask(config=config)
        # Set DataFrame index testing to always return False. Mocks return
        # true for this check otherwise.
        task.testDataFrameIndex = lambda x: False
        diffIm = Mock(spec=afwImage.ExposureF)
        template = Mock(spec=afwImage.ExposureF)
        diaSrc = _makeMockDataFrame()
        ssObjects = _makeMockTable()

        # Each of these subtasks should be called once during diaPipe
        # execution. We use mocks here to check they are being executed
        # appropriately.
        subtasksToMock = [
            "diaCalculation",
        ]
        if doPackageAlerts:
            subtasksToMock.append("alertPackager")
        else:
            self.assertFalse(hasattr(task, "alertPackager"))

        if not doSolarSystemAssociation:
            self.assertFalse(hasattr(task, "solarSystemAssociator"))

        def concatMock(_data, **_kwargs):
            return _makeMockDataFrame()

        # Mock out the run() methods of these Tasks to ensure they
        # return data in the correct form.
        def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, visitInfo,
                                      bbox, wcs):
            return lsst.pipe.base.Struct(nTotalSsObjects=42,
                                         nAssociatedSsObjects=30,
                                         ssoAssocDiaSources=_makeMockTable(),
                                         unAssocDiaSources=_makeMockTable(),
                                         associatedSsSources=_makeMockTable(),
                                         unassociatedSsObjects=_makeMockTable())

        def associator_run(table, diaObjects):
            return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
                                         matchedDiaSources=_makeMockDataFrame(),
                                         unAssocDiaSources=_makeMockDataFrame())

        def loadObjects_run(region, preloadedDiaObjects):
            task.metadata['loadRefreshedDiaObjectsStartUtc'] = 1.234
            task.metadata['loadRefreshedDiaObjectsEndUtc'] = 5.678
            return self.diaObjects

        def updateObjectTableMock(diaObjects, diaSources):
            pass

        # apdb isn't a subtask, but still needs to be mocked out for correct
        # execution in the test environment. The comprehension variable is
        # deliberately not called ``task`` to avoid shadowing the task under
        # test.
        with patch.multiple(task, **{name: DEFAULT for name in subtasksToMock + ["apdb"]}), \
                patch('lsst.ap.association.diaPipe.pd.concat', side_effect=concatMock), \
                patch('lsst.ap.association.diaPipe.DiaPipelineTask.updateObjectTable',
                      side_effect=updateObjectTableMock), \
                patch('lsst.ap.association.diaPipe.DiaPipelineTask.loadRefreshedDiaObjects',
                      side_effect=loadObjects_run), \
                patch('lsst.ap.association.association.AssociationTask.run',
                      side_effect=associator_run) as mainRun, \
                patch('lsst.pipe.tasks.ssoAssociation.SolarSystemAssociationTask.run',
                      side_effect=solarSystemAssociator_run) as ssRun:

            result = task.run(diaSrc,
                              None,
                              diffIm,
                              self.exposure,
                              template,
                              preloadedDiaObjects=self.diaObjects,
                              preloadedDiaSources=self.diaSources,
                              preloadedDiaForcedSources=self.diaForcedSources,
                              band="g",
                              idGenerator=IdGenerator(),
                              solarSystemObjectTable=ssObjects)
            # The subtask mocks only exist inside this ``with`` block.
            for subtaskName in subtasksToMock:
                getattr(task, subtaskName).run.assert_called_once()
            assertValidOutput(task, result)
            # Exact type and contents of apdbMarker are undefined.
            self.assertIsInstance(result.apdbMarker, pexConfig.Config)
            meta = task.getFullMetadata()
            # Check that the expected metadata has been set.
            self.assertEqual(meta["diaPipe.numUpdatedDiaObjects"], 2)
            self.assertEqual(meta["diaPipe.numUnassociatedDiaObjects"], 3)
            # and that associators ran once or not at all.
            mainRun.assert_called_once()
            if doSolarSystemAssociation:
                ssRun.assert_called_once()
            else:
                ssRun.assert_not_called()

    def test_tooManyDiaObjectsError(self):
        """Test that exceeding ``maxNewDiaObjects`` raises an AlgorithmError,
        and that a threshold of 0 disables the check.
        """
        maxNewDiaObjects = 100

        nDiaSources = maxNewDiaObjects + 1
        diaSources = makeDiaSources(nDiaSources, np.zeros(nDiaSources), self.exposure, self.rng)

        def runAndTestWithContextManager(threshold):
            config = self._makeDefaultConfig(config_file=self.config_file.name,
                                             doSolarSystemAssociation=False,
                                             filterUnAssociatedSources=False,
                                             maxNewDiaObjects=threshold,
                                             )
            task = DiaPipelineTask(config=config)
            # Expect an error only for a positive threshold that the number
            # of sources exceeds.
            contextManager = self.assertRaises(pipeBase.AlgorithmError) if nDiaSources > threshold > 0 \
                else contextlib.nullcontext()
            with contextManager:
                task.associateDiaSources(
                    diaSources,
                    None,
                    None,
                    self.diaObjects,
                )
        # Test cases at, above, and below the threshold as well as at 0.
        runAndTestWithContextManager(0)
        runAndTestWithContextManager(maxNewDiaObjects - 1)
        runAndTestWithContextManager(maxNewDiaObjects)
        runAndTestWithContextManager(maxNewDiaObjects + 1)

    def test_createDiaObjects(self):
        """Test that creating new DiaObjects works as expected.
        """
        nSources = 5
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        diaSources = pd.DataFrame(data=[
            {"ra": 0.04*idx, "dec": 0.04*idx,
             "diaSourceId": idx + 1 + nSources, "diaObjectId": 0,
             "ssObjectId": 0}
            for idx in range(nSources)])

        result = task.createNewDiaObjects(convertDataFrameToSdmSchema(task.schema, diaSources, "DiaSource",
                                                                      skipIndex=True))
        self.assertEqual(nSources, len(result.newDiaObjects))
        # Each new DiaObject takes the ID of the DiaSource that seeded it.
        np.testing.assert_array_equal(
            result.diaSources["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())
        np.testing.assert_array_equal(
            result.newDiaObjects["diaObjectId"].to_numpy(),
            result.diaSources["diaSourceId"].to_numpy())

    def test_purgeDiaObjects(self):
        """Remove diaObjects that are outside an image's bounding box.
        """

        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        exposure = makeExposure(False, False)
        nObj0 = 20

        # Create diaObjects
        diaObjects = makeDiaObjects(nObj0, exposure, self.rng)
        # Shrink the bounding box so that some of the diaObjects will be outside
        bbox = exposure.getBBox()
        size = np.minimum(bbox.getHeight(), bbox.getWidth())
        bbox.grow(-size//4)
        exposureCut = exposure[bbox]
        sizeCut = np.minimum(bbox.getHeight(), bbox.getWidth())
        buffer = 10
        bbox.grow(buffer)

        def check_diaObjects(bbox, wcs, diaObjects):
            """Return a boolean array: True where a diaObject falls in bbox."""
            raVals = diaObjects.ra.to_numpy()
            decVals = diaObjects.dec.to_numpy()
            xVals, yVals = wcs.skyToPixelArray(raVals, decVals, degrees=True)
            selector = bbox.contains(xVals, yVals)
            return selector

        selector0 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects)
        nIn0 = np.count_nonzero(selector0)
        nOut0 = np.count_nonzero(~selector0)
        self.assertEqual(nObj0, nIn0 + nOut0)

        # Add an ID that is not in the diaObject table. It should not get removed.
        diaObjectIds0 = diaObjects["diaObjectId"].copy(deep=True)
        diaObjectIds0[max(diaObjectIds0.index) + 1] = 999
        diaObjects1, objIds = task.purgeDiaObjects(exposureCut.getBBox(), exposureCut.getWcs(), diaObjects,
                                                   diaObjectIds=diaObjectIds0, buffer=buffer)
        diaObjectIds1 = diaObjects1["diaObjectId"]
        # Verify that the bounding box was not changed
        sizeCheck = np.minimum(exposureCut.getBBox().getHeight(), exposureCut.getBBox().getWidth())
        self.assertEqual(sizeCut, sizeCheck)
        selector1 = check_diaObjects(bbox, exposureCut.getWcs(), diaObjects1)
        nIn1 = np.count_nonzero(selector1)
        nOut1 = np.count_nonzero(~selector1)
        nObj1 = len(diaObjects1)
        self.assertEqual(nObj1, nIn0)
        # Verify that not all diaObjects were removed
        self.assertGreater(nObj1, 0)
        # Check that some diaObjects were removed
        self.assertLess(nObj1, nObj0)
        # Verify that no objects outside the bounding box remain
        self.assertEqual(nOut1, 0)
        # Verify that no objects inside the bounding box were removed
        self.assertEqual(nIn1, nIn0)
        # The length of the updated object IDs should equal the number of objects
        # plus one, since we added an extra ID.
        self.assertEqual(nObj1 + 1, len(objIds))
        # All of the object IDs extracted from the catalog should be in the pruned object IDs
        self.assertTrue(set(objIds).issuperset(diaObjectIds1))
        # The pruned object IDs should contain entries that are not in the catalog
        self.assertFalse(set(diaObjectIds1).issuperset(objIds))
        # Some IDs should have been removed
        self.assertLess(len(objIds), len(diaObjectIds0))

    def test_filterDiaObjects(self):
        """Unassociated diaSources that are filtered should have good reliability and SNR.
        Glint trail sources should also be filtered out.
        """

        config = self._makeDefaultConfig(config_file=self.config_file.name,
                                         doPackageAlerts=False,
                                         filterUnAssociatedSources=True)

        configBadFilter = self._makeDefaultConfig(config_file=self.config_file.name,
                                                  doPackageAlerts=False,
                                                  filterUnAssociatedSources=True,
                                                  newObjectFluxField="notAFlux")
        configBadFlag = self._makeDefaultConfig(config_file=self.config_file.name,
                                                doPackageAlerts=False,
                                                filterUnAssociatedSources=True,
                                                newObjectBadFlags=("junkSource", "notUsed"))
        # Invalid flux-field or flag names must be rejected at construction.
        with self.assertRaises(pipeBase.InvalidQuantumError):
            DiaPipelineTask(config=configBadFilter)
        with self.assertRaises(pipeBase.InvalidQuantumError):
            DiaPipelineTask(config=configBadFlag)
        task = DiaPipelineTask(config=config)
        nUnassociatedDiaSources = 234

        # Create diaSources
        diaSources = makeDiaSources(nUnassociatedDiaSources,
                                    np.zeros(nUnassociatedDiaSources),
                                    self.exposure,
                                    self.rng,
                                    flagList=task.config.newObjectBadFlags)
        reliability = self.rng.random(nUnassociatedDiaSources)
        flux = (self.rng.random(nUnassociatedDiaSources)**2)*100
        fluxErr = np.sqrt(flux)
        glint_trail = np.zeros(nUnassociatedDiaSources, dtype=bool)
        glint_trail[12:16] = True  # add 4 glint trail sources
        diaSources["reliability"] = reliability
        diaSources[config.newObjectFluxField] = flux
        diaSources[config.newObjectFluxField + "Err"] = fluxErr
        diaSources["glint_trail"] = glint_trail
        badFlagName = task.config.newObjectBadFlags[0]
        badFlags = np.zeros(nUnassociatedDiaSources, dtype=bool)
        nBadFlags = 20
        badFlags[0:nBadFlags] = True
        diaSources[badFlagName] = badFlags

        def runAndCheckFilter(diaSources, snrThreshold=None, lowReliabilitySnrThreshold=None,
                              reliabilityThreshold=None, lowSnrReliabilityThreshold=None,
                              badFlags=None,
                              ):
            """Filter a copy of ``diaSources`` and check every surviving
            source passes the requested cuts and is not a glint trail.
            """

            filterResults = task.filterSources(
                diaSources.copy(deep=True),
                snrThreshold=snrThreshold,
                lowReliabilitySnrThreshold=lowReliabilitySnrThreshold,
                reliabilityThreshold=reliabilityThreshold,
                lowSnrReliabilityThreshold=lowSnrReliabilityThreshold,
                badFlags=badFlags,
            )
            # Every input source ends up in exactly one of the two outputs.
            self.assertEqual(len(filterResults.goodSources) + len(filterResults.badSources),
                             nUnassociatedDiaSources)
            goodFlux = filterResults.goodSources[config.newObjectFluxField]
            goodFluxErr = filterResults.goodSources[config.newObjectFluxField + "Err"]
            goodSnr = np.array(goodFlux/goodFluxErr)
            self.assertTrue(np.all(goodSnr > snrThreshold))
            goodReliability = np.array(filterResults.goodSources["reliability"])
            self.assertTrue(np.all(goodReliability > reliabilityThreshold))
            # Low-SNR survivors must clear the stricter reliability cut.
            goodLowSnrFlag = goodSnr < lowReliabilitySnrThreshold
            lowSnrReliability = goodReliability[goodLowSnrFlag]
            self.assertTrue(np.all(lowSnrReliability > lowSnrReliabilityThreshold))
            glintTrailSources = np.array(filterResults.goodSources["glint_trail"])
            self.assertFalse(np.any(glintTrailSources))

        # Start with every threshold disabled (glint trail sources must still
        # be rejected), then tighten the SNR and reliability cuts step by step.
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0)
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0,
                          badFlags=[badFlagName])
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=0, lowReliabilitySnrThreshold=0,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5)
        runAndCheckFilter(diaSources,
                          snrThreshold=2, lowReliabilitySnrThreshold=8,
                          reliabilityThreshold=0.1, lowSnrReliabilityThreshold=0.5,
                          badFlags=[badFlagName])

    def test_mergeEmptyCatalog(self):
        """Test that a catalog is unchanged if it is merged with an empty
        catalog.
        """
        diaSourcesBase = self.diaSources

        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        # Include some but not all columns that should be in diaSourcesBase, and some that are mis-matched
        diaSourcesEmpty = pd.DataFrame(columns=["ra", "dec", "foo"])
        diaSourcesTest = task.mergeCatalogs(diaSourcesBase, diaSourcesEmpty, tableName="DiaSource")
        self.assertTrue(diaSourcesBase.equals(diaSourcesTest))

    def test_mergeCatalogs(self):
        """Test that a merged catalog is concatenated correctly.
        """
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)

        diaSourcesBase = convertDataFrameToSdmSchema(task.schema, self.diaSources, "DiaSource",
                                                     skipIndex=True)
        nBase = len(diaSourcesBase)
        nNew = int(nBase/2)

        diaSourcesNew = makeDiaSources(nNew, self.diaObjects["diaObjectId"].to_numpy(), self.exposure,
                                       self.rng)
        diaSourcesNew = convertDataFrameToSdmSchema(task.schema, diaSourcesNew, "DiaSource", skipIndex=True)
        diaSourcesTest = task.mergeCatalogs(diaSourcesBase, diaSourcesNew, tableName="DiaSource")
        self.assertEqual(len(diaSourcesTest), nBase + nNew)
        # The base rows come first, followed by the new rows, unchanged.
        diaSourcesExtract1 = diaSourcesTest.iloc[:nBase]
        diaSourcesExtract2 = diaSourcesTest.iloc[nBase:]

        pd.testing.assert_frame_equal(diaSourcesBase, diaSourcesExtract1)
        pd.testing.assert_frame_equal(diaSourcesNew, diaSourcesExtract2)

    def test_updateObjectTable(self):
        """Test that the diaObject record is updated with the number of
        diaSources.
        """
        config = self._makeDefaultConfig(config_file=self.config_file.name, doPackageAlerts=False)
        task = DiaPipelineTask(config=config)
        nObjects = 20
        nSrcPerObject = 10
        nExtraSources = 5
        nSources = nSrcPerObject*nObjects + nExtraSources
        # NOTE(review): assumes makeDiaSources assigns object IDs round-robin,
        # so the first nExtraSources objects receive one extra source.
        expectedSourcesPerObject = nSrcPerObject*np.ones(nObjects)
        expectedSourcesPerObject[:nExtraSources] += 1
        diaObjects = makeDiaObjects(nObjects, self.exposure, self.rng)
        diaSources = makeDiaSources(nSources, diaObjects["diaObjectId"].to_numpy(), self.exposure, self.rng)
        updatedDiaObjects = task.updateObjectTable(diaObjects, diaSources)
        np.testing.assert_array_equal(updatedDiaObjects.nDiaSources.values, expectedSourcesPerObject)

510 

511 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Module-level memory test case.

    All behavior is inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """

514 

515 

def setup_module(module):
    """Pytest-style module setup hook.

    Calls `lsst.utils.tests.init` once before this module's tests run;
    ``module`` (the module object) is accepted but not used.
    """
    lsst.utils.tests.init()

518 

519 

# Allow running this test module directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()