Coverage for tests / test_packageAlerts.py: 19%

363 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 09:12 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import io 

23import os 

24 

25import numpy as np 

26import pandas as pd 

27import tempfile 

28import unittest 

29from unittest.mock import patch, Mock 

30from astropy import wcs 

31from astropy.nddata import CCDData 

32import fastavro 

33try: 

34 import confluent_kafka 

35 from confluent_kafka import KafkaException 

36except ImportError: 

37 confluent_kafka = None 

38 

39import lsst.alert.packet as alertPack 

40from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask 

41from lsst.ap.association.utils import readSchemaFromApdb 

42from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

43import lsst.afw.image as afwImage 

44from lsst.daf.base import DateTime 

45from lsst.dax.apdb import Apdb, ApdbSql 

46import lsst.geom as geom 

47import lsst.meas.base.tests 

48from lsst.sphgeom import Box 

49import lsst.utils.tests 

50from lsst.pipe.tasks.functors import LocalWcs 

51from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema 

52import utils_tests 

53 

54 

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `astropy.time.Time`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects.
    sources : `pandas.DataFrame`
        Round tripped sources.
    """
    # A throwaway sqlite-backed APDB; the temp file lives only for the
    # duration of this call.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + tmpFile.name)
        apdb = Apdb.from_config(apdbConfig)

        # Merge any pre-existing rows (none for a fresh database) with the
        # caller's catalogs before storing.
        wholeSky = Box.full()
        loadedObjects = apdb.getDiaObjects(wholeSky)
        if loadedObjects.empty:
            diaObjects = objects
        else:
            diaObjects = pd.concat([loadedObjects, objects])
        loadedDiaSources = apdb.getDiaSources(wholeSky, [], dateTime)
        if loadedDiaSources.empty:
            diaSources = sources
        else:
            diaSources = pd.concat([loadedDiaSources, sources])
        loadedDiaForcedSources = apdb.getDiaForcedSources(wholeSky, [], dateTime)
        if loadedDiaForcedSources.empty:
            diaForcedSources = forcedSources
        else:
            diaForcedSources = pd.concat([loadedDiaForcedSources, forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        # Read everything back so the DataFrames carry the schemas the APDB
        # actually produces.
        diaObjects = apdb.getDiaObjects(wholeSky)
        diaSources = apdb.getDiaSources(wholeSky,
                                        np.unique(diaObjects["diaObjectId"]),
                                        dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, np.unique(diaObjects["diaObjectId"]), dateTime)

        # Index layout the tests rely on: diaSources keyed by
        # (diaObjectId, band, diaSourceId); drop=False keeps the columns too.
        diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                             drop=False,
                             inplace=True)
        diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

        # apply SDM type standardization to catch pandas typing issues
        schema = readSchemaFromApdb(apdb)
        diaObjects = convertDataFrameToSdmSchema(schema, diaObjects, tableName="DiaObject", skipIndex=True)
        diaSources = convertDataFrameToSdmSchema(schema, diaSources, tableName="DiaSource", skipIndex=True)
        diaForcedSources = convertDataFrameToSdmSchema(schema, diaForcedSources, tableName="DiaForcedSource",
                                                       skipIndex=True)

        return (diaObjects, diaSources, diaForcedSources)

121 

122 

# Visit and detector ids shared by the mock alerts and the test exposure
# built in TestPackageAlerts.setUp.
VISIT = 2
DETECTOR = 42

125 

126 

def mock_alert(dia_source_id):
    """Generate a minimal mock alert.

    Parameters
    ----------
    dia_source_id : `int`
        Value for the alert's top-level ``diaSourceId`` field.

    Returns
    -------
    alert : `dict`
        Minimal alert payload containing a nested ``diaSource`` record.
    """
    dia_source = {
        "midpointMjdTai": 5,
        "diaSourceId": 1234,
        "visit": VISIT,
        "detector": DETECTOR,
        "band": 'g',
        "ra": 12.5,
        "dec": -16.9,
        # These types are 32-bit floats in the avro schema, so we have to
        # make them that type here, so that they round trip appropriately.
        "x": np.float32(15.7),
        "y": np.float32(89.8),
        "apFlux": np.float32(54.85),
        "apFluxErr": np.float32(70.0),
        "snr": np.float32(6.7),
        "psfFlux": np.float32(700.0),
        "psfFluxErr": np.float32(90.0),
        # unlike in transformDiaSourceCatalog.py we need a timezone-aware
        # version because mock_alert does not go through pandas
        "timeProcessedMjdTai": DateTime.now().get(system=DateTime.MJD, scale=DateTime.TAI),
    }
    return {
        "diaSourceId": dia_source_id,
        "diaSource": dia_source,
    }

154 

155 

def mock_ss_alert(dia_source_id, ss_object_id):
    """Generate a minimal mock solar system alert.

    Extends `mock_alert` with ``mpc_orbits`` and ``ssSource`` records so
    the solar-system fields can be round tripped as well.

    Parameters
    ----------
    dia_source_id : `int`
        Value for the alert's ``diaSourceId`` fields.
    ss_object_id : `int`
        Value for the ``ssObjectId`` field of the ``ssSource`` record.

    Returns
    -------
    alert : `dict`
        Minimal solar system alert payload.
    """
    alert = mock_alert(dia_source_id)
    alert['mpc_orbits'] = {
        'id': 42,
        'designation': '2020 AH11',  # a string-typed field
        'q': np.float64(0.99999),  # a double-typed field
        'unpacked_primary_provisional_designation': '2020 AH11',
        'packed_primary_provisional_designation': 'J201AH1',
    }
    alert['ssSource'] = {
        'diaSourceId': dia_source_id,
        'ssObjectId': ss_object_id,
        'eclLambda': np.float64(3.141592),  # a double-typed field
        'eclBeta': np.float64(3.141592),  # a double-typed field
        'galLon': np.float64(3.141592),  # a double-typed field
        'galLat': np.float64(3.141592),  # a double-typed field
        'helioRange': np.float32(3.141592),  # a float-typed field
    }
    return alert

178 

179 

def _deserialize_alert(alert_bytes):
    """Deserialize an alert message from Kafka.

    Parameters
    ----------
    alert_bytes : `bytes`
        Binary-encoding serialized Avro alert, including Confluent Wire
        Format prefix.

    Returns
    -------
    alert : `dict`
        An alert payload.
    """
    schema = alertPack.Schema.from_uri(str(alertPack.get_uri_to_latest_schema()))
    # Skip the 5-byte Confluent Wire Format prefix before the Avro payload.
    content_bytes = io.BytesIO(alert_bytes[5:])

    return fastavro.schemaless_reader(content_bytes, schema.definition)

198 

199 

class TestPackageAlerts(lsst.utils.tests.TestCase):
    def setUp(self):
        """Build the shared fixtures: a small test exposure with visit/filter
        metadata, APDB-round-tripped DiaObject/DiaSource/DiaForcedSource
        catalogs, and a reference cutout WCS.
        """
        # Create an instance of random generator with fixed seed.
        rng = np.random.default_rng(1234)

        # Fake Kafka credentials so tasks configured with
        # doProduceAlerts=True can be constructed in the tests below.
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=1234)
        self.exposure = exposure
        detector = DetectorWrapper(id=DETECTOR, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            id=VISIT,
            exposureTime=200.,
            date=DateTime("2014-05-13T17:00:00.000000000",
                          DateTime.Timescale.TAI),
            boresightRotAngle=geom.Angle(0.785398))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        # Round trip the mock catalogs through an APDB so they carry the
        # exact table schemas production code produces.
        diaObjects = utils_tests.makeDiaObjects(2, self.exposure, rng)
        diaSourceHistory = utils_tests.makeDiaSources(
            10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        diaForcedSources = utils_tests.makeDiaForcedSources(
            10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date.toAstropy())
        diaSourceHistory["programId"] = 0

        # Split the history into "current" sources (the two selected rows,
        # keyed by (diaObjectId, band, diaSourceId)) and prior history.
        self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :]
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
                                                              (2, "g", 10)])

        # Reference astropy WCS matching the exposure's WCS at the cutout
        # center, used to validate cutout CCDData products.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

262 def testCreateExtentMinimum(self): 

263 """Test the extent creation for the cutout bbox returns a cutout with 

264 the minimum cutouut size. 

265 """ 

266 packConfig = PackageAlertsConfig() 

267 # Just create a minimum less than the default cutout size. 

268 packConfig.minCutoutSize = self.cutoutSize - 5 

269 packageAlerts = PackageAlertsTask(config=packConfig) 

270 extent = packageAlerts.createDiaSourceExtent( 

271 packConfig.minCutoutSize - 5) 

272 self.assertTrue(extent == geom.Extent2I(packConfig.minCutoutSize, 

273 packConfig.minCutoutSize)) 

274 # Test that the cutout size is correctly increased. 

275 extent = packageAlerts.createDiaSourceExtent(self.cutoutSize) 

276 self.assertTrue(extent == geom.Extent2I(self.cutoutSize, 

277 self.cutoutSize)) 

278 

279 def testCreateExtentMaximum(self): 

280 """Test the extent creation for the cutout bbox returns a cutout with 

281 the maximum cutout size. 

282 """ 

283 packConfig = PackageAlertsConfig() 

284 # Just create a maximum more than the default cutout size. 

285 packConfig.maxCutoutSize = self.cutoutSize + 5 

286 packageAlerts = PackageAlertsTask(config=packConfig) 

287 extent = packageAlerts.createDiaSourceExtent( 

288 packConfig.maxCutoutSize + 5) 

289 self.assertTrue(extent == geom.Extent2I(packConfig.maxCutoutSize, 

290 packConfig.maxCutoutSize)) 

291 # Test that the cutout size is correctly reduced. 

292 extent = packageAlerts.createDiaSourceExtent(self.cutoutSize) 

293 self.assertTrue(extent == geom.Extent2I(self.cutoutSize, 

294 self.cutoutSize)) 

295 

296 def testCreateCcdDataCutout(self): 

297 """Test that the data is being extracted into the CCDData cutout 

298 correctly. 

299 """ 

300 packageAlerts = PackageAlertsTask() 

301 

302 diaSrcId = 1234 

303 ccdData = packageAlerts.createCcdDataCutout( 

304 self.exposure, 

305 self.exposure.getWcs().getSkyOrigin(), 

306 self.exposure.getWcs().getPixelOrigin(), 

307 self.exposure.getBBox().getDimensions(), 

308 self.exposure.getPhotoCalib(), 

309 diaSrcId) 

310 calibExposure = self.exposure.getPhotoCalib().calibrateImage( 

311 self.exposure.getMaskedImage()) 

312 

313 self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd, 

314 self.cutoutWcs.wcs.cd) 

315 self.assertFloatsAlmostEqual(ccdData.data, 

316 calibExposure.getImage().array) 

317 self.assertFloatsAlmostEqual(ccdData.psf, 

318 self.exposure.psf.computeKernelImage(self.center).array) 

319 

320 ccdData = packageAlerts.createCcdDataCutout( 

321 self.exposure, 

322 geom.SpherePoint(0, 0, geom.degrees), 

323 geom.Point2D(-100, -100), 

324 self.exposure.getBBox().getDimensions(), 

325 self.exposure.getPhotoCalib(), 

326 diaSrcId) 

327 self.assertTrue(ccdData is None) 

328 

329 def testMakeLocalTransformMatrix(self): 

330 """Test that the local WCS approximation is correct. 

331 """ 

332 packageAlerts = PackageAlertsTask() 

333 

334 sphPoint = self.exposure.getWcs().pixelToSky(self.center) 

335 cutout = self.exposure.getCutout(sphPoint, 

336 geom.Extent2I(self.cutoutSize, 

337 self.cutoutSize)) 

338 cd = packageAlerts.makeLocalTransformMatrix( 

339 cutout.getWcs(), self.center, sphPoint) 

340 self.assertFloatsAlmostEqual( 

341 cd, 

342 cutout.getWcs().getCdMatrix(), 

343 rtol=5e-10, 

344 atol=5e-10) 

345 

346 def testStreamCcdDataToBytes(self): 

347 """Test round tripping an CCDData cutout to bytes and back. 

348 """ 

349 packageAlerts = PackageAlertsTask() 

350 

351 sphPoint = self.exposure.getWcs().pixelToSky(self.center) 

352 cutout = self.exposure.getCutout(sphPoint, 

353 geom.Extent2I(self.cutoutSize, 

354 self.cutoutSize)) 

355 cutoutCcdData = CCDData( 

356 data=cutout.getImage().array, 

357 wcs=self.cutoutWcs, 

358 unit="adu") 

359 

360 cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData) 

361 with io.BytesIO(cutoutBytes) as bytesIO: 

362 cutoutFromBytes = CCDData.read(bytesIO, format="fits") 

363 self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data) 

364 

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        dia_source_id = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is the (diaObjectId, band, diaSourceId) index tuple.
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            pixelPoint = geom.Point2D(diaSource["x"], diaSource["y"])
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                pixelPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234,
                rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # Histories for this diaObject only.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            # Same cutout passed as difference, science and template.
            alert = packageAlerts.makeAlertDict(
                dia_source_id,
                self.exposure.visitInfo.getObservationReason(),
                self.exposure.visitInfo.getObject(),
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 13)

            self.assertEqual(alert["diaSourceId"], dia_source_id)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            # setUp's VisitInfo does not set these fields, so they arrive
            # as None in the alert.
            self.assertIsNone(alert["observation_reason"])
            self.assertIsNone(alert["target_name"])
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutScience"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)
            # The ROTPA header must be identical across the stored cutouts.
            science_cutout = CCDData.read(io.BytesIO(alert["cutoutScience"]),
                                          format="fits")
            template_cutout = CCDData.read(io.BytesIO(alert["cutoutTemplate"]),
                                           format="fits")
            self.assertAlmostEqual(science_cutout.header["ROTPA"],
                                   template_cutout.header["ROTPA"])

421 

422 def testMakeAlertDictSchedulerFields(self): 

423 """Test non-null scheduler fields pass through as expected. 

424 

425 """ 

426 packageAlerts = PackageAlertsTask() 

427 dia_source_id = 1234 

428 

429 for srcIdx, diaSource in self.diaSources.iterrows(): 

430 objSources = self.diaSourceHistory.loc[srcIdx[0]] 

431 objForcedSources = self.diaForcedSources.loc[srcIdx[0]] 

432 obs_reason = f"obs_reason_{srcIdx}", 

433 target = f"target_name_{srcIdx}", 

434 alert = packageAlerts.makeAlertDict( 

435 dia_source_id, 

436 obs_reason, 

437 target, 

438 diaSource, 

439 self.diaObjects.loc[srcIdx[0]], 

440 objSources, 

441 objForcedSources, 

442 None, 

443 None, 

444 None) 

445 self.assertEqual(len(alert), 13) 

446 

447 self.assertEqual(alert["observation_reason"], obs_reason) 

448 self.assertEqual(alert["target_name"], target) 

449 

    def testCutoutRotpa(self):
        """Test that the ROTPA header keyword matches the boresightRotAngle from visitInfo.
        """
        packageAlerts = PackageAlertsTask()

        # Create a cutout using existing test exposure
        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))

        # Create CCDData cutout, passing the visit's boresight rotation
        # angle through as rotPa.
        ccdCutout = packageAlerts.createCcdDataCutout(
            cutout,
            sphPoint,
            self.center,
            geom.Extent2I(self.cutoutSize, self.cutoutSize),
            cutout.getPhotoCalib(),
            1234,
            rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())

        # Independently derive a position angle from the exposure's CD
        # matrix as a cross-check on the ROTPA value.
        wcs = self.exposure.wcs
        cd_matrix = wcs.getCdMatrix()  # This gets the CD matrix elements
        angle_rad = LocalWcs.computePositionAngle([0],
                                                  [cd_matrix[0, 0]],  # CD1_1
                                                  [cd_matrix[0, 1]],  # CD1_2
                                                  [cd_matrix[1, 0]],  # CD2_1
                                                  [cd_matrix[1, 1]]  # CD2_2
                                                  )
        wcs_pa = geom.Angle(angle_rad.iloc[0], geom.radians).asDegrees()

        # Round trip through bytes so we check the header actually written.
        cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")

        self.assertIn('ROTPA', cutoutFromBytes.header)
        self.assertAlmostEqual(
            cutoutFromBytes.header['ROTPA'],
            self.exposure.visitInfo.boresightRotAngle.asDegrees())
        self.assertAlmostEqual(
            cutoutFromBytes.header['ROTPA'],
            wcs_pa, places=4)

492 

493 @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') 

494 def test_produceAlerts_empty_password(self): 

495 """ Test that produceAlerts raises if the password is empty or missing. 

496 """ 

497 self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = "" 

498 with self.assertRaisesRegex(ValueError, "Kafka password"): 

499 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

500 PackageAlertsTask(config=packConfig) 

501 

502 del self.environ['AP_KAFKA_PRODUCER_PASSWORD'] 

503 with self.assertRaisesRegex(ValueError, "Kafka password"): 

504 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

505 PackageAlertsTask(config=packConfig) 

506 

507 @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') 

508 def test_produceAlerts_empty_username(self): 

509 """ Test that produceAlerts raises if the username is empty or missing. 

510 """ 

511 self.environ['AP_KAFKA_PRODUCER_USERNAME'] = "" 

512 with self.assertRaisesRegex(ValueError, "Kafka username"): 

513 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

514 PackageAlertsTask(config=packConfig) 

515 

516 del self.environ['AP_KAFKA_PRODUCER_USERNAME'] 

517 with self.assertRaisesRegex(ValueError, "Kafka username"): 

518 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

519 PackageAlertsTask(config=packConfig) 

520 

521 @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') 

522 def test_produceAlerts_empty_server(self): 

523 """ Test that produceAlerts raises if the server is empty or missing. 

524 """ 

525 self.environ['AP_KAFKA_SERVER'] = "" 

526 with self.assertRaisesRegex(ValueError, "Kafka server"): 

527 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

528 PackageAlertsTask(config=packConfig) 

529 

530 del self.environ['AP_KAFKA_SERVER'] 

531 with self.assertRaisesRegex(ValueError, "Kafka server"): 

532 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

533 PackageAlertsTask(config=packConfig) 

534 

535 @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') 

536 def test_produceAlerts_empty_topic(self): 

537 """ Test that produceAlerts raises if the topic is empty or missing. 

538 """ 

539 self.environ['AP_KAFKA_TOPIC'] = "" 

540 with self.assertRaisesRegex(ValueError, "Kafka topic"): 

541 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

542 PackageAlertsTask(config=packConfig) 

543 

544 del self.environ['AP_KAFKA_TOPIC'] 

545 with self.assertRaisesRegex(ValueError, "Kafka topic"): 

546 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

547 PackageAlertsTask(config=packConfig) 

548 

549 @patch('confluent_kafka.Producer') 

550 @patch.object(PackageAlertsTask, '_server_check') 

551 @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') 

552 def test_produceAlerts_success(self, mock_server_check, mock_producer): 

553 """ Test that produceAlerts calls the producer on all provided alerts 

554 when the alerts are all under the batch size limit. 

555 """ 

556 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

557 packageAlerts = PackageAlertsTask(config=packConfig) 

558 alerts = [mock_alert(1), mock_alert(2), mock_ss_alert(3, 3)] 

559 

560 # Create a variable and assign it an instance of the patched kafka producer 

561 producer_instance = mock_producer.return_value 

562 producer_instance.produce = Mock() 

563 producer_instance.flush = Mock() 

564 unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix 

565 exposure_time = self.exposure.visitInfo.exposureTime 

566 packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time) 

567 

568 self.assertEqual(mock_server_check.call_count, 1) 

569 self.assertEqual(producer_instance.produce.call_count, len(alerts)) 

570 self.assertEqual(producer_instance.flush.call_count, len(alerts)+1) 

571 

572 @patch('confluent_kafka.Producer') 

573 @patch.object(PackageAlertsTask, '_server_check') 

574 @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') 

575 def test_produceAlerts_one_failure(self, mock_server_check, mock_producer): 

576 """ Test that produceAlerts correctly fails on one alert 

577 and is writing the failure to disk. 

578 """ 

579 counter = 0 

580 

581 def mock_produce(*args, **kwargs): 

582 nonlocal counter 

583 counter += 1 

584 if counter == 2: 

585 raise KafkaException 

586 else: 

587 return 

588 

589 packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True) 

590 packageAlerts = PackageAlertsTask(config=packConfig) 

591 

592 patcher = patch("builtins.open") 

593 patch_open = patcher.start() 

594 alerts = [mock_alert(1), mock_alert(2), mock_alert(3), mock_ss_alert(4, 4)] 

595 unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix 

596 exposure_time = self.exposure.visitInfo.exposureTime 

597 

598 producer_instance = mock_producer.return_value 

599 producer_instance.produce = Mock(side_effect=mock_produce) 

600 producer_instance.flush = Mock() 

601 packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time) 

602 

603 self.assertEqual(mock_server_check.call_count, 1) 

604 self.assertEqual(producer_instance.produce.call_count, len(alerts)) 

605 self.assertEqual(patch_open.call_count, 1) 

606 self.assertIn(f"{VISIT}_{DETECTOR}_2.avro", patch_open.call_args.args[0]) 

607 # Because one produce raises, we call flush one fewer times than in the success 

608 # test above. 

609 self.assertEqual(producer_instance.flush.call_count, len(alerts)) 

610 patcher.stop() 

611 

612 @patch.object(PackageAlertsTask, '_server_check') 

613 def testRun_without_produce(self, mock_server_check): 

614 """Test the run method of package alerts with produce set to False and 

615 doWriteAlerts set to true. 

616 """ 

617 packConfig = PackageAlertsConfig(doWriteAlerts=True) 

618 with tempfile.TemporaryDirectory(prefix='alerts') as tempdir: 

619 packConfig.alertWriteLocation = tempdir 

620 packageAlerts = PackageAlertsTask(config=packConfig) 

621 

622 packageAlerts.run(self.diaSources, 

623 self.diaObjects, 

624 self.diaSourceHistory, 

625 self.diaForcedSources, 

626 self.exposure, 

627 self.exposure, 

628 self.exposure) 

629 

630 self.assertEqual(mock_server_check.call_count, 0) 

631 

632 with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f: 

633 writer_schema, data_stream = \ 

634 packageAlerts.alertSchema.retrieve_alerts(f) 

635 data = list(data_stream) 

636 

637 self.assertEqual(len(data), len(self.diaSources)) 

638 for idx, alert in enumerate(data): 

639 for key, value in alert["diaSource"].items(): 

640 if isinstance(value, float): 

641 if np.isnan(self.diaSources.iloc[idx][key]): 

642 self.assertTrue(np.isnan(value)) 

643 else: 

644 self.assertAlmostEqual( 

645 1 - value / self.diaSources.iloc[idx][key], 

646 0.) 

647 elif value is None: 

648 self.assertTrue(pd.isna(self.diaSources.iloc[idx][key])) 

649 else: 

650 self.assertEqual(value, self.diaSources.iloc[idx][key]) 

651 sphPoint = geom.SpherePoint(alert["diaSource"]["ra"], 

652 alert["diaSource"]["dec"], 

653 geom.degrees) 

654 pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"]) 

655 cutout = self.exposure.getCutout(sphPoint, 

656 geom.Extent2I(self.cutoutSize, 

657 self.cutoutSize)) 

658 ccdCutout = packageAlerts.createCcdDataCutout( 

659 cutout, 

660 sphPoint, 

661 pixelPoint, 

662 geom.Extent2I(self.cutoutSize, self.cutoutSize), 

663 cutout.getPhotoCalib(), 

664 1234, 

665 rotPa=cutout.visitInfo.boresightRotAngle.asDegrees()) 

666 

667 self.assertEqual(alert["cutoutDifference"], 

668 packageAlerts.streamCcdDataToBytes(ccdCutout)) 

669 

670 self.assertEqual(alert["cutoutDifference"], 

671 packageAlerts.streamCcdDataToBytes(ccdCutout)) 

672 

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce_use_averagePsf(self, mock_server_check):
        """Test the run method of package alerts with produce set to False,
        doWriteAlerts set to true, and useAveragePsf enabled.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
            packConfig.alertWriteLocation = tempdir
            # Differs from testRun_without_produce only in this setting.
            packConfig.useAveragePsf = True
            packageAlerts = PackageAlertsTask(config=packConfig)

            packageAlerts.run(self.diaSources,
                              self.diaObjects,
                              self.diaSourceHistory,
                              self.diaForcedSources,
                              self.exposure,
                              self.exposure,
                              self.exposure)

            # Producing is disabled, so the Kafka server is never checked.
            self.assertEqual(mock_server_check.call_count, 0)

            with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
                writer_schema, data_stream = \
                    packageAlerts.alertSchema.retrieve_alerts(f)
                data = list(data_stream)

            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                # Compare every diaSource field against the input catalog,
                # treating NaN/None and float round-off specially.
                for key, value in alert["diaSource"].items():
                    if isinstance(value, float):
                        if np.isnan(self.diaSources.iloc[idx][key]):
                            self.assertTrue(np.isnan(value))
                        else:
                            self.assertAlmostEqual(
                                1 - value / self.diaSources.iloc[idx][key],
                                0.)
                    elif value is None:
                        self.assertTrue(pd.isna(self.diaSources.iloc[idx][key]))
                    else:
                        self.assertEqual(value, self.diaSources.iloc[idx][key])
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["dec"],
                                            geom.degrees)
                pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"])
                cutout = self.exposure.getCutout(sphPoint,
                                                 geom.Extent2I(self.cutoutSize,
                                                               self.cutoutSize))
                ccdCutout = packageAlerts.createCcdDataCutout(
                    cutout,
                    sphPoint,
                    pixelPoint,
                    geom.Extent2I(self.cutoutSize, self.cutoutSize),
                    cutout.getPhotoCalib(),
                    1234,
                    rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())
                self.assertEqual(alert["cutoutDifference"],
                                 packageAlerts.streamCcdDataToBytes(ccdCutout))

730 

731 @patch.object(PackageAlertsTask, 'produceAlerts') 

732 @patch('confluent_kafka.Producer') 

733 @patch.object(PackageAlertsTask, '_server_check') 

734 @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') 

735 def testRun_with_produce(self, mock_produceAlerts, mock_server_check, mock_producer): 

736 """Test that packageAlerts calls produceAlerts when doProduceAlerts 

737 is set to True. 

738 """ 

739 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

740 packageAlerts = PackageAlertsTask(config=packConfig) 

741 

742 packageAlerts.run(self.diaSources, 

743 self.diaObjects, 

744 self.diaSourceHistory, 

745 self.diaForcedSources, 

746 self.exposure, 

747 self.exposure, 

748 self.exposure) 

749 self.assertEqual(mock_server_check.call_count, 1) 

750 self.assertEqual(mock_produceAlerts.call_count, 1) 

751 

752 def test_serialize_alert_round_trip(self): 

753 """Test that values in the alert packet exactly round trip. 

754 """ 

755 packClass = PackageAlertsConfig() 

756 packageAlerts = PackageAlertsTask(config=packClass) 

757 

758 alert = mock_ss_alert(1, 1) 

759 serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert) 

760 deserialized = _deserialize_alert(serialized) 

761 for table in ['diaSource', 'ssSource', 'mpc_orbits']: 

762 for field in alert[table]: 

763 self.assertEqual(alert[table][field], deserialized[table][field]) 

764 

765 self.assertEqual(1, deserialized["diaSourceId"]) 

766 

767 @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') 

768 def test_server_check(self): 

769 

770 with self.assertRaisesRegex(KafkaException, "_TRANSPORT"): 

771 packConfig = PackageAlertsConfig(doProduceAlerts=True) 

772 PackageAlertsTask(config=packConfig) 

773 

774 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    # Runs the standard LSST memory/leak checks over this test module.
    pass

777 

778 

def setup_module(module):
    """Pytest hook: initialize the LSST test framework before tests run."""
    lsst.utils.tests.init()

781 

782 

# Script entry point: initialize the LSST test framework, then hand off to
# unittest's runner.  (Coverage-report annotation text that had been fused
# onto this line has been removed.)
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()