Coverage for tests/test_packageAlerts.py: 27%

260 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-19 02:24 -0700

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24 

25import numpy as np 

26import pandas as pd 

27import shutil 

28import tempfile 

29import unittest 

30from unittest.mock import patch, Mock 

31from astropy import wcs 

32from astropy.nddata import CCDData 

33import fastavro 

34try: 

35 import confluent_kafka 

36 from confluent_kafka import KafkaException 

37except ImportError: 

38 confluent_kafka = None 

39 

40import lsst.alert.packet as alertPack 

41from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask 

42from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

43import lsst.afw.image as afwImage 

44import lsst.daf.base as dafBase 

45from lsst.dax.apdb import Apdb, ApdbSql, ApdbSqlConfig 

46import lsst.geom as geom 

47import lsst.meas.base.tests 

48from lsst.sphgeom import Box 

49import lsst.utils.tests 

50import utils_tests 

51 

52 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``band``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # Use the temporary file as a context manager so the scratch sqlite
    # database is removed deterministically once the catalogs have been read
    # back, instead of relying on garbage collection of an unclosed handle.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []

        Apdb.makeSchema(apdbConfig)
        apdb = ApdbSql(config=apdbConfig)

        wholeSky = Box.full()
        # Concatenate the test catalogs onto query results from the freshly
        # created database so the stored frames follow the Apdb's column
        # layout (presumably these queries return empty, schema-shaped frames).
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        # Read everything back so the returned frames carry the columns and
        # dtypes produced by the Apdb itself.
        diaObjects = apdb.getDiaObjects(wholeSky)
        diaObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, diaObjectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, diaObjectIds, dateTime)

    # Index the frames the way the tests look them up, keeping the key
    # columns in the data as well (drop=False).
    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)

108 

109 

def mock_alert(alert_id):
    """Build a small, hand-written alert packet for use in the tests.

    Parameters
    ----------
    alert_id : `int`
        Value stored under the ``alertId`` key.

    Returns
    -------
    alert : `dict`
        Alert holding ``alertId`` and a minimal ``diaSource`` record.
    """
    dia_source = {
        "midpointMjdTai": 5,
        "diaSourceId": 4,
        "ccdVisitId": 2,
        "band": 'g',
        "ra": 12.5,
        "dec": -16.9,
    }
    # The Avro schema declares these columns as 32-bit floats; storing them
    # as np.float32 up front lets the values round trip unchanged.
    float32_fields = (
        ("x", 15.7),
        ("y", 89.8),
        ("apFlux", 54.85),
        ("apFluxErr", 70.0),
        ("snr", 6.7),
        ("psfFlux", 700.0),
        ("psfFluxErr", 90.0),
    )
    for field_name, field_value in float32_fields:
        dia_source[field_name] = np.float32(field_value)
    dia_source["flags"] = 12345
    return {"alertId": alert_id, "diaSource": dia_source}

134 

135 

def _deserialize_alert(alert_bytes):
    """Decode a Kafka alert message back into a Python dictionary.

    Parameters
    ----------
    alert_bytes : `bytes`
        Serialized alert: a 5-byte Confluent Wire Format prefix followed by
        the schemaless, binary-encoded Avro payload.

    Returns
    -------
    alert : `dict`
        The decoded alert payload.
    """
    latest_schema_uri = str(alertPack.get_uri_to_latest_schema())
    latest_schema = alertPack.Schema.from_uri(latest_schema_uri)

    # Skip the Confluent Wire Format prefix; only the Avro body is decoded.
    wire_prefix_length = 5
    payload = io.BytesIO(alert_bytes[wire_prefix_length:])

    return fastavro.schemaless_reader(payload, latest_schema.definition)

154 

155 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Unit tests for `PackageAlertsTask` and its helpers."""

    def setUp(self):
        # Provide fake Kafka settings so tasks configured with
        # doProduceAlerts=True can be constructed; patch.dict restores
        # os.environ when the cleanup runs.
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure)
        diaSourceHistory = utils_tests.makeDiaSources(
            10, diaObjects["diaObjectId"], self.exposure)
        diaForcedSources = utils_tests.makeDiaForcedSources(
            10, diaObjects["diaObjectId"], self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date)
        # Normalize missing values coming back from the Apdb (None) to NaN.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan,
                                      inplace=True)
        diaSourceHistory["programId"] = 0

        # Two sources become the "new" DiaSources to package; the rest stay
        # as history. Copy the selection explicitly so adding the bboxSize
        # column below does not trigger pandas' SettingWithCopyWarning.
        newSourceIndex = [(1, "g", 9), (2, "g", 10)]
        self.diaSources = diaSourceHistory.loc[newSourceIndex, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceIndex)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the minimum is expected to be raised to it.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A center at (0, 0) deg, presumably off this exposure, must not
        # produce a cutout.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # srcIdx is (diaObjectId, band, diaSourceId); select by object.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """ Test that produceAlerts raises if the password is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """ Test that produceAlerts raises if the username is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_USERNAME']
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """ Test that produceAlerts raises if the server is empty or missing.
        """
        self.environ['AP_KAFKA_SERVER'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_SERVER']
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """ Test that produceAlerts raises if the topic is empty or missing.
        """
        self.environ['AP_KAFKA_TOPIC'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_TOPIC']
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    @patch('confluent_kafka.Producer')
    def test_produceAlerts_success(self, mock_producer):
        """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2)]
        ccdVisitId = 123

        # Create a variable and assign it an instance of the patched kafka producer
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    @patch('confluent_kafka.Producer')
    def test_produceAlerts_one_failure(self, mock_producer):
        """ Test that produceAlerts correctly fails on one alert
        and is writing the failure to disk.
        """
        counter = 0

        def mock_produce(*args, **kwargs):
            # Raise only for the second alert produced.
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException

        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        # Intercept writes of the failed alert. Registering stop() as a
        # cleanup guarantees builtins.open is restored even if an assertion
        # below fails; otherwise the patch would leak into later tests.
        patcher = patch("builtins.open")
        patch_open = patcher.start()
        self.addCleanup(patcher.stop)
        alerts = [mock_alert(1), mock_alert(2), mock_alert(3)]
        ccdVisitId = 123

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()

        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn("123_2.avro", patch_open.call_args.args[0])
        # Because one produce raises, we call flush one fewer times than in the success
        # test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))

    def testRun_without_produce(self):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to true.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Remove the output directory even if the test fails part way.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data_stream = packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Compare relative difference, presumably because
                        # some schema columns are 32-bit floats and do not
                        # round trip exactly.
                        self.assertAlmostEqual(1 - value / expected, 0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    def testRun_with_produce(self, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.
        """
        # Stacked patch decorators pass their mocks bottom-up, so the
        # innermost patch (confluent_kafka.Producer) is the first argument.
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure)

        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        packConfig = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=packConfig)

        alert = mock_alert(1)
        serialized = packageAlerts._serializeAlert(alert)
        deserialized = _deserialize_alert(serialized)

        for field in alert['diaSource']:
            self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
        self.assertEqual(1, deserialized["alertId"])

542 

543 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Reuse the leak checks from `lsst.utils.tests.MemoryTestCase`.

    No additional tests are defined here.
    """

546 

547 

def setup_module(module):
    """pytest module-level setup hook.

    Initialise `lsst.utils.tests` before the tests in this module run.
    """
    lsst.utils.tests.init()

550 

551 

if __name__ == "__main__":
    # Support running this file directly: initialise the LSST test
    # utilities, then hand control to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()