Coverage for tests/test_packageAlerts.py: 31%
264 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 09:33 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
25import numpy as np
26import pandas as pd
27import tempfile
28import unittest
29from unittest.mock import patch, Mock
30from astropy import wcs
31from astropy.nddata import CCDData
32import fastavro
33try:
34 import confluent_kafka
35 from confluent_kafka import KafkaException
36except ImportError:
37 confluent_kafka = None
39import lsst.alert.packet as alertPack
40from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
41from lsst.afw.cameraGeom.testUtils import DetectorWrapper
42import lsst.afw.image as afwImage
43import lsst.daf.base as dafBase
44from lsst.dax.apdb import Apdb, ApdbSql
45import lsst.geom as geom
46import lsst.meas.base.tests
47from lsst.sphgeom import Box
48import lsst.utils.tests
49import utils_tests
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Push test catalogs through an Apdb instance so they come back with
    the exact table schemas the Apdb produces.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `astropy.time.Time`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects.
    sources : `pandas.DataFrame`
        Round tripped sources.
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources.
    """
    with tempfile.NamedTemporaryFile() as dbFile:
        # Back the Apdb with a throwaway sqlite database file.
        config = ApdbSql.init_database(db_url="sqlite:///" + dbFile.name)
        database = Apdb.from_config(config)
        everywhere = Box.full()

        # Concatenate the test rows onto the (empty) frames returned by the
        # Apdb so the combined frames carry the canonical column set.
        mergedObjects = pd.concat(
            [database.getDiaObjects(everywhere), objects])
        mergedSources = pd.concat(
            [database.getDiaSources(everywhere, [], dateTime), sources])
        mergedForced = pd.concat(
            [database.getDiaForcedSources(everywhere, [], dateTime),
             forcedSources])

        database.store(dateTime, mergedObjects, mergedSources, mergedForced)

        # Read everything back so each frame has the Apdb-defined schema.
        outObjects = database.getDiaObjects(everywhere)
        objectIds = np.unique(outObjects["diaObjectId"])
        outSources = database.getDiaSources(everywhere, objectIds, dateTime)
        outForced = database.getDiaForcedSources(everywhere, objectIds,
                                                 dateTime)

    # Index the frames the way the rest of the tests expect to slice them.
    outObjects.set_index("diaObjectId", drop=False, inplace=True)
    outSources.set_index(["diaObjectId", "band", "diaSourceId"],
                         drop=False,
                         inplace=True)
    outForced.set_index(["diaObjectId"], drop=False, inplace=True)

    return (outObjects, outSources, outForced)
# Visit and detector identifiers used consistently across the mock alerts
# and the test exposure so packaged alerts match the exposure metadata.
VISIT = 2
DETECTOR = 42
def mock_alert(alert_id):
    """Generate a minimal mock alert.

    Parameters
    ----------
    alert_id : `int`
        Identifier to stamp on the generated alert.

    Returns
    -------
    alert : `dict`
        Minimal alert payload containing a single embedded diaSource record.
    """
    diaSource = {
        "midpointMjdTai": 5,
        "diaSourceId": 1234,
        "visit": VISIT,
        "detector": DETECTOR,
        "band": 'g',
        "ra": 12.5,
        "dec": -16.9,
    }
    # These fields are 32-bit floats in the avro schema, so they have to be
    # created as that type here for the values to round trip exactly.
    float32Fields = (("x", 15.7),
                     ("y", 89.8),
                     ("apFlux", 54.85),
                     ("apFluxErr", 70.0),
                     ("snr", 6.7),
                     ("psfFlux", 700.0),
                     ("psfFluxErr", 90.0))
    for name, value in float32Fields:
        diaSource[name] = np.float32(value)

    return {"alertId": alert_id, "diaSource": diaSource}
def _deserialize_alert(alert_bytes):
    """Deserialize an alert message from Kafka.

    Parameters
    ----------
    alert_bytes : `bytes`
        Binary-encoded serialized Avro alert, including the Confluent Wire
        Format prefix.

    Returns
    -------
    alert : `dict`
        An alert payload.
    """
    latest_schema = alertPack.Schema.from_uri(
        str(alertPack.get_uri_to_latest_schema()))
    # The first five bytes are the Confluent Wire Format header; the avro
    # payload starts immediately after it.
    with io.BytesIO(alert_bytes[5:]) as payload:
        return fastavro.schemaless_reader(payload, latest_schema.definition)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests of cutout creation, alert dictionary/avro packaging, and Kafka
    alert production for `PackageAlertsTask`.
    """

    def setUp(self):
        # Create an instance of random generator with fixed seed.
        rng = np.random.default_rng(1234)

        # Provide fake Kafka credentials so tasks configured with
        # doProduceAlerts=True can be constructed in tests.
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=1234)
        self.exposure = exposure
        detector = DetectorWrapper(id=DETECTOR, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            id=VISIT,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure, rng)
        diaSourceHistory = utils_tests.makeDiaSources(
            10, diaObjects["diaObjectId"], self.exposure, rng)
        diaForcedSources = utils_tests.makeDiaForcedSources(
            10, diaObjects["diaObjectId"], self.exposure, rng)
        # Round trip through the Apdb so all catalogs carry the canonical
        # table schemas and indexes.
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date.toAstropy())
        diaSourceHistory["programId"] = 0

        # Split the most recent source per object out as "current" sources;
        # the remainder plays the role of source history.
        self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :]
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
                                                              (2, "g", 10)])

        # Reference WCS the cutout CCDData objects are expected to carry.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # Requests below the minimum are clamped up to the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertTrue(extent == geom.Extent2I(packConfig.minCutoutSize,
                                                packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertTrue(extent == geom.Extent2I(self.cutoutSize,
                                                self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A center far off the exposure must yield no cutout.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertTrue(ccdData is None)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 10)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutScience"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """ Test that produceAlerts raises if the password is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """ Test that produceAlerts raises if the username is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_USERNAME']
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """ Test that produceAlerts raises if the server is empty or missing.
        """
        self.environ['AP_KAFKA_SERVER'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_SERVER']
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """ Test that produceAlerts raises if the topic is empty or missing.
        """
        self.environ['AP_KAFKA_TOPIC'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_TOPIC']
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_success(self, mock_server_check, mock_producer):
        """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2)]

        # Create a variable and assign it an instance of the patched kafka producer
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, VISIT, DETECTOR)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_one_failure(self, mock_server_check, mock_producer):
        """ Test that produceAlerts correctly fails on one alert
        and is writing the failure to disk.
        """
        counter = 0

        def mock_produce(*args, **kwargs):
            # Fail only on the second alert produced.
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException
            else:
                return

        packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        patcher = patch("builtins.open")
        patch_open = patcher.start()
        alerts = [mock_alert(1), mock_alert(2), mock_alert(3)]

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, VISIT, DETECTOR)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn(f"{VISIT}_{DETECTOR}_2.avro", patch_open.call_args.args[0])
        # Because one produce raises, we call flush one fewer times than in the success
        # test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))
        patcher.stop()

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce(self, mock_server_check):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to true.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
            packConfig.alertWriteLocation = tempdir
            packageAlerts = PackageAlertsTask(config=packConfig)

            packageAlerts.run(self.diaSources,
                              self.diaObjects,
                              self.diaSourceHistory,
                              self.diaForcedSources,
                              self.exposure,
                              self.exposure,
                              self.exposure)

            self.assertEqual(mock_server_check.call_count, 0)

            with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
                writer_schema, data_stream = \
                    packageAlerts.alertSchema.retrieve_alerts(f)
                data = list(data_stream)

            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                for key, value in alert["diaSource"].items():
                    if isinstance(value, float):
                        if np.isnan(self.diaSources.iloc[idx][key]):
                            self.assertTrue(np.isnan(value))
                        else:
                            # Compare as a relative difference to tolerate
                            # float32 round tripping through avro.
                            self.assertAlmostEqual(
                                1 - value / self.diaSources.iloc[idx][key],
                                0.)
                    else:
                        self.assertEqual(value, self.diaSources.iloc[idx][key])
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["dec"],
                                            geom.degrees)
                cutout = self.exposure.getCutout(sphPoint,
                                                 geom.Extent2I(self.cutoutSize,
                                                               self.cutoutSize))
                ccdCutout = packageAlerts.createCcdDataCutout(
                    cutout,
                    sphPoint,
                    geom.Extent2I(self.cutoutSize, self.cutoutSize),
                    cutout.getPhotoCalib(),
                    1234)
                self.assertEqual(alert["cutoutDifference"],
                                 packageAlerts.streamCcdDataToBytes(ccdCutout))

    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def testRun_with_produce(self, mock_server_check, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.
        """
        # NOTE: stacked @patch decorators inject their mocks bottom-up, so
        # the bottom decorator (_server_check) supplies the first argument,
        # then Producer, then produceAlerts.  The previous parameter order
        # bound each name to the wrong mock and the assertions below only
        # passed by coincidence.
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          self.exposure)
        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        packClass = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=packClass)

        alert = mock_alert(1)
        serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert)
        deserialized = _deserialize_alert(serialized)

        for field in alert['diaSource']:
            self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
        self.assertEqual(1, deserialized["alertId"])

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_server_check(self):
        # The fake server in the environment is unreachable, so the broker
        # check performed at task construction must raise.
        with self.assertRaisesRegex(KafkaException, "_TRANSPORT"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    # Standard LSST leak-check test case; the base class supplies the tests.
    pass
def setup_module(module):
    """Initialize the LSST testing framework before this module's tests run."""
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Allow running this test file directly as a script.
    lsst.utils.tests.init()
    unittest.main()