Coverage for tests / test_packageAlerts.py: 19%
363 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-24 08:39 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-24 08:39 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
25import numpy as np
26import pandas as pd
27import tempfile
28import unittest
29from unittest.mock import patch, Mock
30from astropy import wcs
31from astropy.nddata import CCDData
32import fastavro
33try:
34 import confluent_kafka
35 from confluent_kafka import KafkaException
36except ImportError:
37 confluent_kafka = None
39import lsst.alert.packet as alertPack
40from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
41from lsst.ap.association.utils import readSchemaFromApdb
42from lsst.afw.cameraGeom.testUtils import DetectorWrapper
43import lsst.afw.image as afwImage
44from lsst.daf.base import DateTime
45from lsst.dax.apdb import Apdb, ApdbSql
46import lsst.geom as geom
47import lsst.meas.base.tests
48from lsst.sphgeom import Box
49import lsst.utils.tests
50from lsst.pipe.tasks.functors import LocalWcs
51from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema
52import utils_tests
def _appendToLoaded(loaded, new):
    """Return ``new`` alone if ``loaded`` is empty, else both concatenated."""
    if loaded.empty:
        return new
    return pd.concat([loaded, new])


def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `astropy.time.Time`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``band``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    with tempfile.NamedTemporaryFile() as tmpFile:
        # A throwaway sqlite APDB backed by the temporary file; every use of
        # ``apdb`` stays inside this context so the file outlives it.
        apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + tmpFile.name)
        apdb = Apdb.from_config(apdbConfig)

        wholeSky = Box.full()
        diaObjects = _appendToLoaded(apdb.getDiaObjects(wholeSky), objects)
        diaSources = _appendToLoaded(
            apdb.getDiaSources(wholeSky, [], dateTime), sources)
        diaForcedSources = _appendToLoaded(
            apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources)

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        # Read everything back so the frames carry the APDB's column types.
        diaObjects = apdb.getDiaObjects(wholeSky)
        diaObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, diaObjectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(wholeSky, diaObjectIds, dateTime)

        diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                             drop=False,
                             inplace=True)
        diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

        # apply SDM type standardization to catch pandas typing issues
        schema = readSchemaFromApdb(apdb)
        diaObjects = convertDataFrameToSdmSchema(schema, diaObjects, tableName="DiaObject", skipIndex=True)
        diaSources = convertDataFrameToSdmSchema(schema, diaSources, tableName="DiaSource", skipIndex=True)
        diaForcedSources = convertDataFrameToSdmSchema(schema, diaForcedSources, tableName="DiaForcedSource",
                                                       skipIndex=True)

    return (diaObjects, diaSources, diaForcedSources)
# Visit and detector identifiers shared by the test exposure's metadata and
# by the mock alerts below.
VISIT, DETECTOR = 2, 42
def mock_alert(dia_source_id):
    """Generate a minimal mock alert.

    Parameters
    ----------
    dia_source_id : `int`
        Identifier used for both the top-level ``diaSourceId`` and the
        nested ``diaSource`` record's ``diaSourceId``.

    Returns
    -------
    alert : `dict`
        Alert payload containing ``diaSourceId`` and a minimal
        ``diaSource`` record.
    """
    return {
        "diaSourceId": dia_source_id,
        "diaSource": {
            "midpointMjdTai": 5,
            # Keep the nested id consistent with the top-level one so the
            # mock alert is self-consistent.
            "diaSourceId": dia_source_id,
            "visit": VISIT,
            "detector": DETECTOR,
            "band": 'g',
            "ra": 12.5,
            "dec": -16.9,
            # These types are 32-bit floats in the avro schema, so we have to
            # make them that type here, so that they round trip appropriately.
            "x": np.float32(15.7),
            "y": np.float32(89.8),
            "apFlux": np.float32(54.85),
            "apFluxErr": np.float32(70.0),
            "snr": np.float32(6.7),
            "psfFlux": np.float32(700.0),
            "psfFluxErr": np.float32(90.0),
            # mock_alert does not go through pandas, so the processing time
            # is computed directly as a TAI MJD via lsst.daf.base.DateTime.
            "timeProcessedMjdTai": DateTime.now().get(system=DateTime.MJD, scale=DateTime.TAI)
        }
    }
def mock_ss_alert(dia_source_id, ss_object_id):
    """Generate a minimal mock alert for a solar-system detection.

    The result is a `mock_alert` payload extended with ``mpc_orbits`` and
    ``ssSource`` records that cover string, double, and float typed fields.

    Parameters
    ----------
    dia_source_id : `int`
        Identifier of the DiaSource; also stored in ``ssSource``.
    ss_object_id : `int`
        Identifier stored as ``ssSource.ssObjectId``.

    Returns
    -------
    alert : `dict`
        The extended alert payload.
    """
    orbits = dict(
        id=42,
        designation='2020 AH11',  # a string-typed field
        q=np.float64(0.99999),  # a double-typed field
        unpacked_primary_provisional_designation='2020 AH11',
        packed_primary_provisional_designation='J201AH1',
    )
    ssRecord = dict(
        diaSourceId=dia_source_id,
        ssObjectId=ss_object_id,
        eclLambda=np.float64(3.141592),  # a double-typed field
        eclBeta=np.float64(3.141592),  # a double-typed field
        galLon=np.float64(3.141592),  # a double-typed field
        galLat=np.float64(3.141592),  # a double-typed field
        helioRange=np.float32(3.141592),  # a float-typed field
    )
    alert = mock_alert(dia_source_id)
    alert.update(mpc_orbits=orbits, ssSource=ssRecord)
    return alert
def _deserialize_alert(alert_bytes):
    """Decode a Kafka alert message back into a Python dictionary.

    Parameters
    ----------
    alert_bytes : `bytes`
        Message payload: a 5-byte Confluent Wire Format prefix followed by
        the schemaless binary Avro encoding of the alert.

    Returns
    -------
    alert : `dict`
        The decoded alert payload.
    """
    # Length of the Confluent Wire Format prefix (presumably the magic byte
    # plus a 4-byte schema id). The body is decoded against the latest alert
    # schema rather than the id carried in the prefix.
    prefixLength = 5
    latestSchema = alertPack.Schema.from_uri(
        str(alertPack.get_uri_to_latest_schema()))
    avroBody = io.BytesIO(alert_bytes[prefixLength:])
    return fastavro.schemaless_reader(avroBody, latestSchema.definition)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests of `PackageAlertsTask` cutout creation, alert construction,
    serialization, writing to disk and Kafka production.
    """

    def setUp(self):
        # Create an instance of random generator with fixed seed.
        rng = np.random.default_rng(1234)

        # Provide fake Kafka settings so tasks configured with
        # doProduceAlerts=True can be constructed. Individual tests edit or
        # delete these entries; patch.dict restores them on cleanup.
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=1234)
        self.exposure = exposure
        detector = DetectorWrapper(id=DETECTOR, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            id=VISIT,
            exposureTime=200.,
            date=DateTime("2014-05-13T17:00:00.000000000",
                          DateTime.Timescale.TAI),
            boresightRotAngle=geom.Angle(0.785398))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure, rng)
        diaSourceHistory = utils_tests.makeDiaSources(
            10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        diaForcedSources = utils_tests.makeDiaForcedSources(
            10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date.toAstropy())
        diaSourceHistory["programId"] = 0

        # One source per DiaObject is treated as the new detection to package;
        # the rest form the per-object source history.
        alertSourceKeys = [(1, "g", 9), (2, "g", 10)]
        # Explicit copy so the column assignment below modifies an
        # independent frame (avoids pandas' SettingWithCopyWarning).
        self.diaSources = diaSourceHistory.loc[alertSourceKeys, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=alertSourceKeys)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtentMinimum(self):
        """Test the extent creation for the cutout bbox returns a cutout with
        the minimum cutout size.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout size.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A size below the minimum is increased to the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # A size above the minimum is left unchanged.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateExtentMaximum(self):
        """Test the extent creation for the cutout bbox returns a cutout with
        the maximum cutout size.
        """
        packConfig = PackageAlertsConfig()
        # Just create a maximum more than the default cutout size.
        packConfig.maxCutoutSize = self.cutoutSize + 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A size above the maximum is reduced to the maximum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.maxCutoutSize + 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.maxCutoutSize,
                                               packConfig.maxCutoutSize))
        # A size below the maximum is left unchanged.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getWcs().getPixelOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)
        self.assertFloatsAlmostEqual(ccdData.psf,
                                     self.exposure.psf.computeKernelImage(self.center).array)

        # A cutout centered far outside the exposure yields no data.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            geom.Point2D(-100, -100),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=5e-10,
            atol=5e-10)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        dia_source_id = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # The index is (diaObjectId, band, diaSourceId).
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            pixelPoint = geom.Point2D(diaSource["x"], diaSource["y"])
            extent = geom.Extent2I(self.cutoutSize, self.cutoutSize)
            cutout = self.exposure.getCutout(sphPoint, extent)
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                pixelPoint,
                extent,
                cutout.getPhotoCalib(),
                dia_source_id,
                rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
            alert = packageAlerts.makeAlertDict(
                dia_source_id,
                self.exposure.visitInfo.getObservationReason(),
                self.exposure.visitInfo.getObject(),
                diaSource,
                self.diaObjects.loc[diaObjectId],
                self.diaSourceHistory.loc[diaObjectId],
                self.diaForcedSources.loc[diaObjectId],
                ccdCutout,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 13)

            self.assertEqual(alert["diaSourceId"], dia_source_id)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            # The VisitInfo built in setUp has no observation reason or object.
            self.assertIsNone(alert["observation_reason"])
            self.assertIsNone(alert["target_name"])
            for cutoutKey in ("cutoutDifference", "cutoutScience", "cutoutTemplate"):
                self.assertEqual(alert[cutoutKey], cutoutBytes)
            science_cutout = CCDData.read(io.BytesIO(alert["cutoutScience"]),
                                          format="fits")
            template_cutout = CCDData.read(io.BytesIO(alert["cutoutTemplate"]),
                                           format="fits")
            self.assertAlmostEqual(science_cutout.header["ROTPA"],
                                   template_cutout.header["ROTPA"])

    def testMakeAlertDictSchedulerFields(self):
        """Test non-null scheduler fields pass through as expected.
        """
        packageAlerts = PackageAlertsTask()
        dia_source_id = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            diaObjectId = srcIdx[0]
            # Plain strings (no trailing comma, which would build a tuple).
            obs_reason = f"obs_reason_{srcIdx}"
            target = f"target_name_{srcIdx}"
            alert = packageAlerts.makeAlertDict(
                dia_source_id,
                obs_reason,
                target,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                self.diaSourceHistory.loc[diaObjectId],
                self.diaForcedSources.loc[diaObjectId],
                None,
                None,
                None)
            self.assertEqual(len(alert), 13)

            self.assertEqual(alert["observation_reason"], obs_reason)
            self.assertEqual(alert["target_name"], target)

    def testCutoutRotpa(self):
        """Test that the ROTPA header keyword matches the boresightRotAngle
        from visitInfo and the position angle derived from the exposure WCS.
        """
        packageAlerts = PackageAlertsTask()

        # Create a cutout using the existing test exposure.
        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        extent = geom.Extent2I(self.cutoutSize, self.cutoutSize)
        cutout = self.exposure.getCutout(sphPoint, extent)

        ccdCutout = packageAlerts.createCcdDataCutout(
            cutout,
            sphPoint,
            self.center,
            extent,
            cutout.getPhotoCalib(),
            1234,
            rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())

        # Named to avoid shadowing the module-level ``astropy.wcs`` import.
        exposureWcs = self.exposure.wcs
        cdMatrix = exposureWcs.getCdMatrix()
        angleRad = LocalWcs.computePositionAngle([0],
                                                 [cdMatrix[0, 0]],  # CD1_1
                                                 [cdMatrix[0, 1]],  # CD1_2
                                                 [cdMatrix[1, 0]],  # CD2_1
                                                 [cdMatrix[1, 1]])  # CD2_2
        wcsPa = geom.Angle(angleRad.iloc[0], geom.radians).asDegrees()

        cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")

        self.assertIn('ROTPA', cutoutFromBytes.header)
        self.assertAlmostEqual(
            cutoutFromBytes.header['ROTPA'],
            self.exposure.visitInfo.boresightRotAngle.asDegrees())
        self.assertAlmostEqual(
            cutoutFromBytes.header['ROTPA'],
            wcsPa, places=4)

    def _assertMissingKafkaSettingRaises(self, envVar, messageRegex):
        """Check that constructing a producing task raises `ValueError`
        matching ``messageRegex`` when ``envVar`` is empty and when it is
        unset.
        """
        self.environ[envVar] = ""
        with self.assertRaisesRegex(ValueError, messageRegex):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ[envVar]
        with self.assertRaisesRegex(ValueError, messageRegex):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """ Test that produceAlerts raises if the password is empty or missing.
        """
        self._assertMissingKafkaSettingRaises('AP_KAFKA_PRODUCER_PASSWORD', "Kafka password")

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """ Test that produceAlerts raises if the username is empty or missing.
        """
        self._assertMissingKafkaSettingRaises('AP_KAFKA_PRODUCER_USERNAME', "Kafka username")

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """ Test that produceAlerts raises if the server is empty or missing.
        """
        self._assertMissingKafkaSettingRaises('AP_KAFKA_SERVER', "Kafka server")

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """ Test that produceAlerts raises if the topic is empty or missing.
        """
        self._assertMissingKafkaSettingRaises('AP_KAFKA_TOPIC', "Kafka topic")

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    def test_produceAlerts_success(self, mock_server_check, mock_producer):
        """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2), mock_ss_alert(3, 3)]

        # Create a variable and assign it an instance of the patched kafka producer
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix
        exposure_time = self.exposure.visitInfo.exposureTime
        packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    def test_produceAlerts_one_failure(self, mock_server_check, mock_producer):
        """ Test that produceAlerts correctly fails on one alert
        and is writing the failure to disk.
        """
        counter = 0

        def mock_produce(*args, **kwargs):
            # Raise on the second produce call only.
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException

        packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        alerts = [mock_alert(1), mock_alert(2), mock_alert(3), mock_ss_alert(4, 4)]
        unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix
        exposure_time = self.exposure.visitInfo.exposureTime

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()
        # Patch ``open`` only around the call under test, via a context
        # manager, so it is always restored even if produceAlerts raises.
        with patch("builtins.open") as patch_open:
            packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn(f"{VISIT}_{DETECTOR}_2.avro", patch_open.call_args.args[0])
        # Because one produce raises, we call flush one fewer times than in the success
        # test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))

    def _assertDiaSourceMatches(self, alertDiaSource, expected):
        """Check that a deserialized alert ``diaSource`` record matches the
        input DiaSource row ``expected``.
        """
        for key, value in alertDiaSource.items():
            if isinstance(value, float):
                if np.isnan(expected[key]):
                    self.assertTrue(np.isnan(value))
                else:
                    # Relative comparison: some fields are 32-bit floats in
                    # the avro schema.
                    self.assertAlmostEqual(1 - value / expected[key], 0.)
            elif value is None:
                self.assertTrue(pd.isna(expected[key]))
            else:
                self.assertEqual(value, expected[key])

    def _assertCutoutDifferenceMatches(self, packageAlerts, alert):
        """Check that an alert's ``cutoutDifference`` equals a cutout built
        directly from the test exposure at the alert's position.
        """
        diaSource = alert["diaSource"]
        sphPoint = geom.SpherePoint(diaSource["ra"], diaSource["dec"], geom.degrees)
        pixelPoint = geom.Point2D(diaSource["x"], diaSource["y"])
        extent = geom.Extent2I(self.cutoutSize, self.cutoutSize)
        cutout = self.exposure.getCutout(sphPoint, extent)
        ccdCutout = packageAlerts.createCcdDataCutout(
            cutout,
            sphPoint,
            pixelPoint,
            extent,
            cutout.getPhotoCalib(),
            1234,
            rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())
        self.assertEqual(alert["cutoutDifference"],
                         packageAlerts.streamCcdDataToBytes(ccdCutout))

    def _runAndCheckWrittenAlerts(self, mock_server_check, **configOverrides):
        """Run the task with alert writing (but not producing) enabled and
        check the alerts it writes to disk.

        Parameters
        ----------
        mock_server_check : `unittest.mock.Mock`
            Patched ``PackageAlertsTask._server_check``; it must not be
            called when alerts are not produced.
        **configOverrides
            Additional `PackageAlertsConfig` fields to set before the task
            is constructed.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
            packConfig.alertWriteLocation = tempdir
            for fieldName, fieldValue in configOverrides.items():
                setattr(packConfig, fieldName, fieldValue)
            packageAlerts = PackageAlertsTask(config=packConfig)

            packageAlerts.run(self.diaSources,
                              self.diaObjects,
                              self.diaSourceHistory,
                              self.diaForcedSources,
                              self.exposure,
                              self.exposure,
                              self.exposure)

            self.assertEqual(mock_server_check.call_count, 0)

            with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
                _, data_stream = packageAlerts.alertSchema.retrieve_alerts(f)
                data = list(data_stream)

            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                self._assertDiaSourceMatches(alert["diaSource"], self.diaSources.iloc[idx])
                self._assertCutoutDifferenceMatches(packageAlerts, alert)

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce(self, mock_server_check):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to true.
        """
        self._runAndCheckWrittenAlerts(mock_server_check)

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce_use_averagePsf(self, mock_server_check):
        """Test the run method of package alerts with produce set to False,
        doWriteAlerts set to true and useAveragePsf set to true.
        """
        self._runAndCheckWrittenAlerts(mock_server_check, useAveragePsf=True)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_with_produce(self, mock_server_check, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.
        """
        # Mock arguments arrive bottom-up: the innermost patch decorator
        # supplies the first argument.
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          self.exposure)
        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        packageAlerts = PackageAlertsTask(config=PackageAlertsConfig())

        alert = mock_ss_alert(1, 1)
        serialized = packageAlerts._serializeAlert(alert)
        deserialized = _deserialize_alert(serialized)
        for table in ['diaSource', 'ssSource', 'mpc_orbits']:
            for field, value in alert[table].items():
                with self.subTest(table=table, field=field):
                    self.assertEqual(value, deserialized[table][field])

        self.assertEqual(1, deserialized["diaSourceId"])

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_server_check(self):
        """Test that constructing a producing task raises `KafkaException`
        when the (fake) Kafka server configured in setUp is unreachable.
        """
        with self.assertRaisesRegex(KafkaException, "_TRANSPORT"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the standard `lsst.utils.tests.MemoryTestCase` checks for this
    test module.
    """
def setup_module(module):
    """Initialize ``lsst.utils.tests`` before pytest runs this module.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module; unused, but required by the pytest hook signature.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Script entry point: initialize the LSST test utilities, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()