Coverage for tests / test_packageAlerts.py: 19%

363 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 08:39 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24 

25import numpy as np 

26import pandas as pd 

27import tempfile 

28import unittest 

29from unittest.mock import patch, Mock 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32import fastavro 

33try: 

34 import confluent_kafka 

35 from confluent_kafka import KafkaException 

36except ImportError: 

37 confluent_kafka = None 

38 

39import lsst.alert.packet as alertPack 

40from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask 

41from lsst.ap.association.utils import readSchemaFromApdb 

42from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

43import lsst.afw.image as afwImage 

44from lsst.daf.base import DateTime 

45from lsst.dax.apdb import Apdb, ApdbSql 

46import lsst.geom as geom 

47import lsst.meas.base.tests 

48from lsst.sphgeom import Box 

49import lsst.utils.tests 

50from lsst.pipe.tasks.functors import LocalWcs 

51from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema 

52import utils_tests 

53 

54 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `astropy.time.Time`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        ``(diaObjectId, band, diaSourceId)``.
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    def _appendTo(loaded, new):
        # Only concatenate when the database already returned rows;
        # otherwise the new catalog is used untouched.
        return new if loaded.empty else pd.concat([loaded, new])

    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + tmpFile.name)
        apdb = Apdb.from_config(apdbConfig)

        wholeSky = Box.full()
        diaObjects = _appendTo(apdb.getDiaObjects(wholeSky), objects)
        diaSources = _appendTo(
            apdb.getDiaSources(wholeSky, [], dateTime), sources)
        diaForcedSources = _appendTo(
            apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources)

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        # Read everything back so the frames carry the database's schema.
        diaObjects = apdb.getDiaObjects(wholeSky)
        objectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, objectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, objectIds, dateTime)

        diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                             drop=False,
                             inplace=True)
        diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

        # apply SDM type standardization to catch pandas typing issues
        schema = readSchemaFromApdb(apdb)
        diaObjects = convertDataFrameToSdmSchema(
            schema, diaObjects, tableName="DiaObject", skipIndex=True)
        diaSources = convertDataFrameToSdmSchema(
            schema, diaSources, tableName="DiaSource", skipIndex=True)
        diaForcedSources = convertDataFrameToSdmSchema(
            schema, diaForcedSources, tableName="DiaForcedSource",
            skipIndex=True)

        return (diaObjects, diaSources, diaForcedSources)

121 

122 

# Visit and detector IDs shared by the mock alerts and the test exposure.
VISIT = 2
DETECTOR = 42

125 

126 

def mock_alert(dia_source_id):
    """Build a minimal alert dictionary for the producer/serializer tests.

    Parameters
    ----------
    dia_source_id : `int`
        Value stored in the top-level ``diaSourceId`` field of the alert.

    Returns
    -------
    alert : `dict`
        Alert with a top-level ``diaSourceId`` and a ``diaSource`` record.
    """
    # Current time expressed as an MJD in the TAI scale.
    processed_mjd_tai = DateTime.now().get(system=DateTime.MJD,
                                           scale=DateTime.TAI)
    dia_source = {
        "midpointMjdTai": 5,
        # NOTE(review): the nested id is a fixed value, not dia_source_id;
        # confirm this is intentional.
        "diaSourceId": 1234,
        "visit": VISIT,
        "detector": DETECTOR,
        "band": 'g',
        "ra": 12.5,
        "dec": -16.9,
        # These types are 32-bit floats in the avro schema, so we have to
        # make them that type here, so that they round trip appropriately.
        "x": np.float32(15.7),
        "y": np.float32(89.8),
        "apFlux": np.float32(54.85),
        "apFluxErr": np.float32(70.0),
        "snr": np.float32(6.7),
        "psfFlux": np.float32(700.0),
        "psfFluxErr": np.float32(90.0),
        "psfNdata": 25,
        "dipoleNdata": 43,
        "trailNdata": 5,
        "bboxSize": 50,
        "timeProcessedMjdTai": processed_mjd_tai,
    }
    return {"diaSourceId": dia_source_id, "diaSource": dia_source}

158 

159 

def mock_ss_alert(dia_source_id, ss_object_id):
    """Build a minimal mock alert that also carries solar system records.

    The result is `mock_alert` output extended with ``mpc_orbits`` and
    ``ssSource`` entries.

    Parameters
    ----------
    dia_source_id : `int`
        Passed to `mock_alert` and stored in ``ssSource["diaSourceId"]``.
    ss_object_id : `int`
        Stored in ``ssSource["ssObjectId"]``.

    Returns
    -------
    alert : `dict`
        The extended alert dictionary.
    """
    mpc_orbits = {
        'id': 42,
        'designation': '2020 AH11',  # a string-typed field
        'q': np.float64(0.99999),  # a double-typed field
        'unpacked_primary_provisional_designation': '2020 AH11',
        'packed_primary_provisional_designation': 'J201AH1',
    }
    ss_source = {
        'diaSourceId': dia_source_id,
        'ssObjectId': ss_object_id,
        'eclLambda': np.float64(3.141592),  # a double-typed field
        'eclBeta': np.float64(3.141592),  # a double-typed field
        'galLon': np.float64(3.141592),  # a double-typed field
        'galLat': np.float64(3.141592),  # a double-typed field
        'helioRange': np.float32(3.141592),  # a float-typed field
    }

    result = mock_alert(dia_source_id)
    result['mpc_orbits'] = mpc_orbits
    result['ssSource'] = ss_source
    return result

182 

183 

def _deserialize_alert(alert_bytes):
    """Decode one serialized alert back into a dictionary.

    Parameters
    ----------
    alert_bytes : `bytes`
        Binary-encoding serialized Avro alert, including Confluent Wire
        Format prefix.

    Returns
    -------
    alert : `dict`
        An alert payload.
    """
    # Number of leading bytes skipped before the Avro payload starts.
    wire_prefix_length = 5
    payload = io.BytesIO(alert_bytes[wire_prefix_length:])

    latest_uri = str(alertPack.get_uri_to_latest_schema())
    latest_schema = alertPack.Schema.from_uri(latest_uri)
    return fastavro.schemaless_reader(payload, latest_schema.definition)

202 

203 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests of `PackageAlertsTask`: cutout creation, alert dictionary
    construction, serialization, and (when Kafka is installed) alert
    production.
    """

    def setUp(self):
        # Create an instance of random generator with fixed seed.
        rng = np.random.default_rng(1234)

        # Fake Kafka credentials for the duration of each test.
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=1234)
        self.exposure = exposure
        detector = DetectorWrapper(id=DETECTOR, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            id=VISIT,
            exposureTime=200.,
            date=DateTime("2014-05-13T17:00:00.000000000",
                          DateTime.Timescale.TAI),
            boresightRotAngle=geom.Angle(0.785398))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure, rng)
        diaSourceHistory = utils_tests.makeDiaSources(
            10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        diaForcedSources = utils_tests.makeDiaForcedSources(
            10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date.toAstropy())
        diaSourceHistory["programId"] = 0

        # The two "new" sources are split off from the history. Copy the
        # selection so the column assignment below acts on its own frame.
        newSourceKeys = [(1, "g", 9), (2, "g", 10)]
        self.diaSources = diaSourceHistory.loc[newSourceKeys, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceKeys)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def _checkEmptyOrMissingEnvRaises(self, variable, messageRegex):
        """Check that task construction raises `ValueError` when an
        environment variable is empty and when it is missing.

        Parameters
        ----------
        variable : `str`
            Name of the environment variable to blank out and then delete.
        messageRegex : `str`
            Pattern the `ValueError` message must match.
        """
        self.environ[variable] = ""
        with self.assertRaisesRegex(ValueError, messageRegex):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ[variable]
        with self.assertRaisesRegex(ValueError, messageRegex):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    def _runAndReadAlerts(self, packConfig):
        """Run the task with alerts written to a temporary directory and
        read the written alerts back.

        Parameters
        ----------
        packConfig : `PackageAlertsConfig`
            Task configuration; ``alertWriteLocation`` is overwritten here.

        Returns
        -------
        packageAlerts : `PackageAlertsTask`
            The task that was run.
        data : `list` [`dict`]
            Alerts read back from ``{VISIT}_{DETECTOR}.avro``.
        """
        with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
            packConfig.alertWriteLocation = tempdir
            packageAlerts = PackageAlertsTask(config=packConfig)

            packageAlerts.run(self.diaSources,
                              self.diaObjects,
                              self.diaSourceHistory,
                              self.diaForcedSources,
                              self.exposure,
                              self.exposure,
                              self.exposure)

            with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
                writer_schema, data_stream = \
                    packageAlerts.alertSchema.retrieve_alerts(f)
                data = list(data_stream)
        return packageAlerts, data

    def _checkAlertsMatchInputs(self, packageAlerts, data):
        """Compare alerts read back from disk against ``self.diaSources``
        and against a freshly built difference cutout.

        Parameters
        ----------
        packageAlerts : `PackageAlertsTask`
            Task used to rebuild the expected cutouts.
        data : `list` [`dict`]
            Alerts read back from disk, in the order of ``self.diaSources``.
        """
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            for key, value in alert["diaSource"].items():
                expected = self.diaSources.iloc[idx][key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Relative comparison of the float fields.
                        self.assertAlmostEqual(1 - value / expected, 0.)
                elif value is None:
                    self.assertTrue(pd.isna(expected))
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["dec"],
                                        geom.degrees)
            pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"])
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                pixelPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234,
                rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())

            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

    def testCreateExtentMinimum(self):
        """Test the extent creation for the cutout bbox returns a cutout with
        the minimum cutout size.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout size.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the minimum is increased to the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # A request above the minimum is returned unchanged.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateExtentMaximum(self):
        """Test the extent creation for the cutout bbox returns a cutout with
        the maximum cutout size.
        """
        packConfig = PackageAlertsConfig()
        # Just create a maximum more than the default cutout size.
        packConfig.maxCutoutSize = self.cutoutSize + 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request above the maximum is reduced to the maximum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.maxCutoutSize + 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.maxCutoutSize,
                                               packConfig.maxCutoutSize))
        # A request below the maximum is returned unchanged.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getWcs().getPixelOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)
        self.assertFloatsAlmostEqual(ccdData.psf,
                                     self.exposure.psf.computeKernelImage(self.center).array)

        # This position is expected to yield no cutout at all.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            geom.Point2D(-100, -100),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=5e-10,
            atol=5e-10)

    def testStreamCcdDataToBytes(self):
        """Test round tripping a CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        dia_source_id = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            pixelPoint = geom.Point2D(diaSource["x"], diaSource["y"])
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                pixelPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                dia_source_id,
                rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                dia_source_id,
                self.exposure.visitInfo.getObservationReason(),
                self.exposure.visitInfo.getObject(),
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 13)

            self.assertEqual(alert["diaSourceId"], dia_source_id)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertIsNone(alert["observation_reason"])
            self.assertIsNone(alert["target_name"])
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutScience"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)
            science_cutout = CCDData.read(io.BytesIO(alert["cutoutScience"]),
                                          format="fits")
            template_cutout = CCDData.read(io.BytesIO(alert["cutoutTemplate"]),
                                           format="fits")
            self.assertAlmostEqual(science_cutout.header["ROTPA"],
                                   template_cutout.header["ROTPA"])

    def testMakeAlertDictSchedulerFields(self):
        """Test non-null scheduler fields pass through as expected.
        """
        packageAlerts = PackageAlertsTask()
        dia_source_id = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            # Plain strings: a trailing comma here would silently turn
            # these into 1-tuples.
            obs_reason = f"obs_reason_{srcIdx}"
            target = f"target_name_{srcIdx}"
            alert = packageAlerts.makeAlertDict(
                dia_source_id,
                obs_reason,
                target,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                None,
                None,
                None)
            self.assertEqual(len(alert), 13)

            self.assertEqual(alert["observation_reason"], obs_reason)
            self.assertEqual(alert["target_name"], target)

    def testCutoutRotpa(self):
        """Test that the ROTPA header keyword matches the boresightRotAngle
        from visitInfo.
        """
        packageAlerts = PackageAlertsTask()

        # Create a cutout using existing test exposure
        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))

        # Create CCDData cutout
        ccdCutout = packageAlerts.createCcdDataCutout(
            cutout,
            sphPoint,
            self.center,
            geom.Extent2I(self.cutoutSize, self.cutoutSize),
            cutout.getPhotoCalib(),
            1234,
            rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())

        # Named exposureWcs so the imported astropy ``wcs`` module is not
        # shadowed.
        exposureWcs = self.exposure.wcs
        cd_matrix = exposureWcs.getCdMatrix()  # This gets the CD matrix elements
        angle_rad = LocalWcs.computePositionAngle([0],
                                                  [cd_matrix[0, 0]],  # CD1_1
                                                  [cd_matrix[0, 1]],  # CD1_2
                                                  [cd_matrix[1, 0]],  # CD2_1
                                                  [cd_matrix[1, 1]]  # CD2_2
                                                  )
        wcs_pa = geom.Angle(angle_rad.iloc[0], geom.radians).asDegrees()

        cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")

        self.assertIn('ROTPA', cutoutFromBytes.header)
        self.assertAlmostEqual(
            cutoutFromBytes.header['ROTPA'],
            self.exposure.visitInfo.boresightRotAngle.asDegrees())
        self.assertAlmostEqual(
            cutoutFromBytes.header['ROTPA'],
            wcs_pa, places=4)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """ Test that produceAlerts raises if the password is empty or missing.
        """
        self._checkEmptyOrMissingEnvRaises('AP_KAFKA_PRODUCER_PASSWORD',
                                           "Kafka password")

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """ Test that produceAlerts raises if the username is empty or missing.
        """
        self._checkEmptyOrMissingEnvRaises('AP_KAFKA_PRODUCER_USERNAME',
                                           "Kafka username")

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """ Test that produceAlerts raises if the server is empty or missing.
        """
        self._checkEmptyOrMissingEnvRaises('AP_KAFKA_SERVER', "Kafka server")

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """ Test that produceAlerts raises if the topic is empty or missing.
        """
        self._checkEmptyOrMissingEnvRaises('AP_KAFKA_TOPIC', "Kafka topic")

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_success(self, mock_server_check, mock_producer):
        """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2), mock_ss_alert(3, 3)]

        # Create a variable and assign it an instance of the patched kafka producer
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix
        exposure_time = self.exposure.visitInfo.exposureTime
        packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_one_failure(self, mock_server_check, mock_producer):
        """ Test that produceAlerts correctly fails on one alert
        and is writing the failure to disk.
        """
        counter = 0

        def mock_produce(*args, **kwargs):
            # Raise on the second call only.
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException
            else:
                return

        packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        # Register the stop as a cleanup so a failing assertion below cannot
        # leave ``builtins.open`` patched for later tests.
        patcher = patch("builtins.open")
        patch_open = patcher.start()
        self.addCleanup(patcher.stop)
        alerts = [mock_alert(1), mock_alert(2), mock_alert(3), mock_ss_alert(4, 4)]
        unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix
        exposure_time = self.exposure.visitInfo.exposureTime

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn(f"{VISIT}_{DETECTOR}_2.avro", patch_open.call_args.args[0])
        # Because one produce raises, we call flush one fewer times than in the success
        # test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce(self, mock_server_check):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to true.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        packageAlerts, data = self._runAndReadAlerts(packConfig)

        self.assertEqual(mock_server_check.call_count, 0)
        self._checkAlertsMatchInputs(packageAlerts, data)

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce_use_averagePsf(self, mock_server_check):
        """Test the run method of package alerts with produce set to False,
        doWriteAlerts set to true, and useAveragePsf set to true.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        packConfig.useAveragePsf = True
        packageAlerts, data = self._runAndReadAlerts(packConfig)

        self.assertEqual(mock_server_check.call_count, 0)
        self._checkAlertsMatchInputs(packageAlerts, data)

    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def testRun_with_produce(self, mock_server_check, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.
        """
        # Patch decorators inject mocks bottom-up, so the parameter order is
        # the reverse of the decorator order above.
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          self.exposure)
        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        packClass = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=packClass)

        alert = mock_ss_alert(1, 1)
        serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert)
        deserialized = _deserialize_alert(serialized)
        for table in ['diaSource', 'ssSource', 'mpc_orbits']:
            for field in alert[table]:
                self.assertEqual(alert[table][field], deserialized[table][field])

        self.assertEqual(1, deserialized["diaSourceId"])

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_server_check(self):
        """Test that constructing the task with doProduceAlerts=True raises
        a KafkaException matching "_TRANSPORT".

        Presumably this is because the fake server set in `setUp` cannot be
        reached; confirm against ``PackageAlertsTask._server_check``.
        """
        with self.assertRaisesRegex(KafkaException, "_TRANSPORT"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

777 

778 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    # Adds no tests of its own; it only runs whatever checks are inherited
    # from lsst.utils.tests.MemoryTestCase.
    pass

781 

782 

def setup_module(module):
    """Initialize the LSST test utilities for this module.

    Presumably invoked by the test runner as a module-level setup hook;
    ``module`` is not used here.
    """
    lsst.utils.tests.init()

785 

786 

# Allow running this file directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()