Coverage for tests / test_packageAlerts.py: 19%
363 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 08:39 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
25import numpy as np
26import pandas as pd
27import tempfile
28import unittest
29from unittest.mock import patch, Mock
30from astropy import wcs
31from astropy.nddata import CCDData
32import fastavro
33try:
34 import confluent_kafka
35 from confluent_kafka import KafkaException
36except ImportError:
37 confluent_kafka = None
39import lsst.alert.packet as alertPack
40from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
41from lsst.ap.association.utils import readSchemaFromApdb
42from lsst.afw.cameraGeom.testUtils import DetectorWrapper
43import lsst.afw.image as afwImage
44from lsst.daf.base import DateTime
45from lsst.dax.apdb import Apdb, ApdbSql
46import lsst.geom as geom
47import lsst.meas.base.tests
48from lsst.sphgeom import Box
49import lsst.utils.tests
50from lsst.pipe.tasks.functors import LocalWcs
51from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema
52import utils_tests
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `astropy.time.Time`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects.
    sources : `pandas.DataFrame`
        Round tripped sources.
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources.
    """
    def _appendToLoaded(loaded, new):
        # Combine the new catalog with whatever the Apdb already holds.
        # Concatenating with an empty frame can mangle dtypes, so return
        # the new catalog unchanged when nothing was loaded.
        return new if loaded.empty else pd.concat([loaded, new])

    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + tmpFile.name)
        apdb = Apdb.from_config(apdbConfig)

        wholeSky = Box.full()
        diaObjects = _appendToLoaded(apdb.getDiaObjects(wholeSky), objects)
        diaSources = _appendToLoaded(
            apdb.getDiaSources(wholeSky, [], dateTime), sources)
        diaForcedSources = _appendToLoaded(
            apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources)

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        # Read everything back so the frames carry the Apdb table schemas.
        diaObjects = apdb.getDiaObjects(wholeSky)
        diaSources = apdb.getDiaSources(wholeSky,
                                        np.unique(diaObjects["diaObjectId"]),
                                        dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, np.unique(diaObjects["diaObjectId"]), dateTime)

        diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                             drop=False,
                             inplace=True)
        diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

        # apply SDM type standardization to catch pandas typing issues
        schema = readSchemaFromApdb(apdb)
        diaObjects = convertDataFrameToSdmSchema(schema, diaObjects, tableName="DiaObject", skipIndex=True)
        diaSources = convertDataFrameToSdmSchema(schema, diaSources, tableName="DiaSource", skipIndex=True)
        diaForcedSources = convertDataFrameToSdmSchema(schema, diaForcedSources, tableName="DiaForcedSource",
                                                       skipIndex=True)

    return (diaObjects, diaSources, diaForcedSources)
# Visit and detector identifiers shared by the mock alerts and the test
# exposure's VisitInfo/Detector so that written alert filenames match.
VISIT = 2
DETECTOR = 42
def mock_alert(dia_source_id):
    """Generate a minimal mock alert.
    """
    dia_source = {
        "midpointMjdTai": 5,
        "diaSourceId": 1234,
        "visit": VISIT,
        "detector": DETECTOR,
        "band": 'g',
        "ra": 12.5,
        "dec": -16.9,
        # These types are 32-bit floats in the avro schema, so we have to
        # make them that type here, so that they round trip appropriately.
        "x": np.float32(15.7),
        "y": np.float32(89.8),
        "apFlux": np.float32(54.85),
        "apFluxErr": np.float32(70.0),
        "snr": np.float32(6.7),
        "psfFlux": np.float32(700.0),
        "psfFluxErr": np.float32(90.0),
        "psfNdata": int(25),
        "dipoleNdata": int(43),
        "trailNdata": int(5),
        "bboxSize": int(50),
        # unlike in transformDiaSourceCatalog.py we need a timezone-aware
        # version because mock_alert does not go through pandas
        "timeProcessedMjdTai": DateTime.now().get(system=DateTime.MJD, scale=DateTime.TAI),
    }
    return {"diaSourceId": dia_source_id, "diaSource": dia_source}
def mock_ss_alert(dia_source_id, ss_object_id):
    """Generate a minimal mock alert.
    """
    # Double-typed pi used for several ssSource coordinate fields below.
    pi_double = np.float64(3.141592)

    alert = mock_alert(dia_source_id)
    alert['mpc_orbits'] = {
        'id': 42,
        'designation': '2020 AH11',  # a string-typed field
        'q': np.float64(0.99999),  # a double-typed field
        'unpacked_primary_provisional_designation': '2020 AH11',
        'packed_primary_provisional_designation': 'J201AH1',
    }
    alert['ssSource'] = {
        'diaSourceId': dia_source_id,
        'ssObjectId': ss_object_id,
        'eclLambda': pi_double,  # a double-typed field
        'eclBeta': pi_double,  # a double-typed field
        'galLon': pi_double,  # a double-typed field
        'galLat': pi_double,  # a double-typed field
        'helioRange': np.float32(3.141592),  # a float-typed field
    }
    return alert
def _deserialize_alert(alert_bytes):
    """Deserialize an alert message from Kafka.

    Parameters
    ----------
    alert_bytes : `bytes`
        Binary-encoding serialized Avro alert, including Confluent Wire
        Format prefix.

    Returns
    -------
    alert : `dict`
        An alert payload.
    """
    latest_schema = alertPack.Schema.from_uri(
        str(alertPack.get_uri_to_latest_schema()))
    # Strip the 5-byte Confluent Wire Format header before decoding.
    payload = io.BytesIO(alert_bytes[5:])
    return fastavro.schemaless_reader(payload, latest_schema.definition)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    def setUp(self):
        # Create an instance of random generator with fixed seed.
        rng = np.random.default_rng(1234)

        # Fake Kafka credentials so PackageAlertsTask can be constructed.
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=1234)
        self.exposure = exposure
        detector = DetectorWrapper(id=DETECTOR, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            id=VISIT,
            exposureTime=200.,
            date=DateTime("2014-05-13T17:00:00.000000000",
                          DateTime.Timescale.TAI),
            boresightRotAngle=geom.Angle(0.785398))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure, rng)
        diaSourceHistory = utils_tests.makeDiaSources(
            10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        diaForcedSources = utils_tests.makeDiaForcedSources(
            10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date.toAstropy())
        diaSourceHistory["programId"] = 0

        # Split off two sources as "current" detections; the rest are history.
        self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :]
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
                                                              (2, "g", 10)])

        # Reference WCS the cutouts are expected to carry.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtentMinimum(self):
        """Test the extent creation for the cutout bbox returns a cutout with
        the minimum cutout size.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout size.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertTrue(extent == geom.Extent2I(packConfig.minCutoutSize,
                                                packConfig.minCutoutSize))
        # Test that the cutout size is correctly increased.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertTrue(extent == geom.Extent2I(self.cutoutSize,
                                                self.cutoutSize))

    def testCreateExtentMaximum(self):
        """Test the extent creation for the cutout bbox returns a cutout with
        the maximum cutout size.
        """
        packConfig = PackageAlertsConfig()
        # Just create a maximum more than the default cutout size.
        packConfig.maxCutoutSize = self.cutoutSize + 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.maxCutoutSize + 5)
        self.assertTrue(extent == geom.Extent2I(packConfig.maxCutoutSize,
                                                packConfig.maxCutoutSize))
        # Test that the cutout size is correctly reduced.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertTrue(extent == geom.Extent2I(self.cutoutSize,
                                                self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getWcs().getPixelOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)
        self.assertFloatsAlmostEqual(ccdData.psf,
                                     self.exposure.psf.computeKernelImage(self.center).array)

        # A center far off the exposure must yield no cutout.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            geom.Point2D(-100, -100),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertTrue(ccdData is None)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=5e-10,
            atol=5e-10)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        dia_source_id = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            pixelPoint = geom.Point2D(diaSource["x"], diaSource["y"])
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                pixelPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234,
                rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                dia_source_id,
                self.exposure.visitInfo.getObservationReason(),
                self.exposure.visitInfo.getObject(),
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 13)

            self.assertEqual(alert["diaSourceId"], dia_source_id)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            # The test exposure carries no scheduler metadata.
            self.assertIsNone(alert["observation_reason"])
            self.assertIsNone(alert["target_name"])
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutScience"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)
            science_cutout = CCDData.read(io.BytesIO(alert["cutoutScience"]),
                                          format="fits")
            template_cutout = CCDData.read(io.BytesIO(alert["cutoutTemplate"]),
                                           format="fits")
            self.assertAlmostEqual(science_cutout.header["ROTPA"],
                                   template_cutout.header["ROTPA"])

    def testMakeAlertDictSchedulerFields(self):
        """Test non-null scheduler fields pass through as expected.
        """
        packageAlerts = PackageAlertsTask()
        dia_source_id = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            # NOTE: these must be plain strings; a previous version had
            # trailing commas that silently turned them into 1-tuples.
            obs_reason = f"obs_reason_{srcIdx}"
            target = f"target_name_{srcIdx}"
            alert = packageAlerts.makeAlertDict(
                dia_source_id,
                obs_reason,
                target,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                None,
                None,
                None)
            self.assertEqual(len(alert), 13)

            self.assertEqual(alert["observation_reason"], obs_reason)
            self.assertEqual(alert["target_name"], target)

    def testCutoutRotpa(self):
        """Test that the ROTPA header keyword matches the boresightRotAngle from visitInfo.
        """
        packageAlerts = PackageAlertsTask()

        # Create a cutout using existing test exposure
        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))

        # Create CCDData cutout
        ccdCutout = packageAlerts.createCcdDataCutout(
            cutout,
            sphPoint,
            self.center,
            geom.Extent2I(self.cutoutSize, self.cutoutSize),
            cutout.getPhotoCalib(),
            1234,
            rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())

        wcs = self.exposure.wcs
        cd_matrix = wcs.getCdMatrix()  # This gets the CD matrix elements
        angle_rad = LocalWcs.computePositionAngle([0],
                                                  [cd_matrix[0, 0]],  # CD1_1
                                                  [cd_matrix[0, 1]],  # CD1_2
                                                  [cd_matrix[1, 0]],  # CD2_1
                                                  [cd_matrix[1, 1]]  # CD2_2
                                                  )
        wcs_pa = geom.Angle(angle_rad.iloc[0], geom.radians).asDegrees()

        cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")

        self.assertIn('ROTPA', cutoutFromBytes.header)
        self.assertAlmostEqual(
            cutoutFromBytes.header['ROTPA'],
            self.exposure.visitInfo.boresightRotAngle.asDegrees())
        self.assertAlmostEqual(
            cutoutFromBytes.header['ROTPA'],
            wcs_pa, places=4)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """ Test that produceAlerts raises if the password is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """ Test that produceAlerts raises if the username is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_USERNAME']
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """ Test that produceAlerts raises if the server is empty or missing.
        """
        self.environ['AP_KAFKA_SERVER'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_SERVER']
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """ Test that produceAlerts raises if the topic is empty or missing.
        """
        self.environ['AP_KAFKA_TOPIC'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_TOPIC']
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_success(self, mock_server_check, mock_producer):
        """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2), mock_ss_alert(3, 3)]

        # Create a variable and assign it an instance of the patched kafka producer
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix
        exposure_time = self.exposure.visitInfo.exposureTime
        packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_one_failure(self, mock_server_check, mock_producer):
        """ Test that produceAlerts correctly fails on one alert
        and is writing the failure to disk.
        """
        counter = 0

        def mock_produce(*args, **kwargs):
            # Fail only on the second alert to exercise the error path.
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException
            else:
                return

        packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        patcher = patch("builtins.open")
        patch_open = patcher.start()
        alerts = [mock_alert(1), mock_alert(2), mock_alert(3), mock_ss_alert(4, 4)]
        unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix
        exposure_time = self.exposure.visitInfo.exposureTime

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn(f"{VISIT}_{DETECTOR}_2.avro", patch_open.call_args.args[0])
        # Because one produce raises, we call flush one fewer times than in the success
        # test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))
        patcher.stop()

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce(self, mock_server_check):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to true.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
            packConfig.alertWriteLocation = tempdir
            packageAlerts = PackageAlertsTask(config=packConfig)

            packageAlerts.run(self.diaSources,
                              self.diaObjects,
                              self.diaSourceHistory,
                              self.diaForcedSources,
                              self.exposure,
                              self.exposure,
                              self.exposure)

            self.assertEqual(mock_server_check.call_count, 0)

            with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
                writer_schema, data_stream = \
                    packageAlerts.alertSchema.retrieve_alerts(f)
                data = list(data_stream)

            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                for key, value in alert["diaSource"].items():
                    if isinstance(value, float):
                        if np.isnan(self.diaSources.iloc[idx][key]):
                            self.assertTrue(np.isnan(value))
                        else:
                            self.assertAlmostEqual(
                                1 - value / self.diaSources.iloc[idx][key],
                                0.)
                    elif value is None:
                        self.assertTrue(pd.isna(self.diaSources.iloc[idx][key]))
                    else:
                        self.assertEqual(value, self.diaSources.iloc[idx][key])
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["dec"],
                                            geom.degrees)
                pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"])
                cutout = self.exposure.getCutout(sphPoint,
                                                 geom.Extent2I(self.cutoutSize,
                                                               self.cutoutSize))
                ccdCutout = packageAlerts.createCcdDataCutout(
                    cutout,
                    sphPoint,
                    pixelPoint,
                    geom.Extent2I(self.cutoutSize, self.cutoutSize),
                    cutout.getPhotoCalib(),
                    1234,
                    rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())

                self.assertEqual(alert["cutoutDifference"],
                                 packageAlerts.streamCcdDataToBytes(ccdCutout))

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce_use_averagePsf(self, mock_server_check):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to true.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
            packConfig.alertWriteLocation = tempdir
            packConfig.useAveragePsf = True
            packageAlerts = PackageAlertsTask(config=packConfig)

            packageAlerts.run(self.diaSources,
                              self.diaObjects,
                              self.diaSourceHistory,
                              self.diaForcedSources,
                              self.exposure,
                              self.exposure,
                              self.exposure)

            self.assertEqual(mock_server_check.call_count, 0)

            with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
                writer_schema, data_stream = \
                    packageAlerts.alertSchema.retrieve_alerts(f)
                data = list(data_stream)

            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                for key, value in alert["diaSource"].items():
                    if isinstance(value, float):
                        if np.isnan(self.diaSources.iloc[idx][key]):
                            self.assertTrue(np.isnan(value))
                        else:
                            self.assertAlmostEqual(
                                1 - value / self.diaSources.iloc[idx][key],
                                0.)
                    elif value is None:
                        self.assertTrue(pd.isna(self.diaSources.iloc[idx][key]))
                    else:
                        self.assertEqual(value, self.diaSources.iloc[idx][key])
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["dec"],
                                            geom.degrees)
                pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"])
                cutout = self.exposure.getCutout(sphPoint,
                                                 geom.Extent2I(self.cutoutSize,
                                                               self.cutoutSize))
                ccdCutout = packageAlerts.createCcdDataCutout(
                    cutout,
                    sphPoint,
                    pixelPoint,
                    geom.Extent2I(self.cutoutSize, self.cutoutSize),
                    cutout.getPhotoCalib(),
                    1234,
                    rotPa=cutout.visitInfo.boresightRotAngle.asDegrees())
                self.assertEqual(alert["cutoutDifference"],
                                 packageAlerts.streamCcdDataToBytes(ccdCutout))

    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def testRun_with_produce(self, mock_server_check, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.
        """
        # Patch decorators inject mocks bottom-up: _server_check first,
        # then Producer, then produceAlerts; parameters are ordered to match.
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          self.exposure)
        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        packClass = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=packClass)

        alert = mock_ss_alert(1, 1)
        serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert)
        deserialized = _deserialize_alert(serialized)
        for table in ['diaSource', 'ssSource', 'mpc_orbits']:
            for field in alert[table]:
                self.assertEqual(alert[table][field], deserialized[table][field])

        self.assertEqual(1, deserialized["diaSourceId"])

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_server_check(self):
        # The fake server address cannot be resolved, so the transport-level
        # check is expected to fail.
        with self.assertRaisesRegex(KafkaException, "_TRANSPORT"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)
# Standard LSST leak check: runs after the tests in this module.
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass
def setup_module(module):
    """Pytest hook: initialize the LSST test framework for this module."""
    lsst.utils.tests.init()
# Script entry point (a coverage-report annotation had been garbled into
# this line during extraction; restored to the plain guard).
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()