# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import io
import os
import tempfile
import unittest
from unittest.mock import patch, Mock

import numpy as np
import pandas as pd
from astropy import wcs
from astropy.nddata import CCDData
import fastavro
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
from lsst.afw.cameraGeom.testUtils import DetectorWrapper
import lsst.afw.image as afwImage
import lsst.daf.base as dafBase
from lsst.dax.apdb import Apdb, ApdbSql
import lsst.geom as geom
import lsst.meas.base.tests
from lsst.sphgeom import Box
import lsst.utils.tests

import utils_tests

def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `astropy.time.Time`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round-tripped objects.
    sources : `pandas.DataFrame`
        Round-tripped sources.
    forcedSources : `pandas.DataFrame`
        Round-tripped forced sources.
    """
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + tmpFile.name)
        apdb = Apdb.from_config(apdbConfig)

        wholeSky = Box.full()
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        diaObjects = apdb.getDiaObjects(wholeSky)
        diaSources = apdb.getDiaSources(wholeSky,
                                        np.unique(diaObjects["diaObjectId"]),
                                        dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, np.unique(diaObjects["diaObjectId"]), dateTime)

        diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                             drop=False,
                             inplace=True)
        diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

        return (diaObjects, diaSources, diaForcedSources)


def mock_alert(alert_id):
    """Generate a minimal mock alert.
    """
    return {
        "alertId": alert_id,
        "diaSource": {
            "midpointMjdTai": 5,
            "diaSourceId": 4,
            "ccdVisitId": 2,
            "band": 'g',
            "ra": 12.5,
            "dec": -16.9,
            # These types are 32-bit floats in the avro schema, so we have to
            # make them that type here, so that they round trip appropriately.
            "x": np.float32(15.7),
            "y": np.float32(89.8),
            "apFlux": np.float32(54.85),
            "apFluxErr": np.float32(70.0),
            "snr": np.float32(6.7),
            "psfFlux": np.float32(700.0),
            "psfFluxErr": np.float32(90.0),
            "flags": 12345,
        }
    }
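
# Illustrative note (not part of the test suite): the np.float32 casts above
# matter because Avro "float" fields are 32 bits wide. A Python double that
# is not exactly representable in 32 bits will not compare equal after a
# round trip, e.g.:
#
#   >>> np.float32(54.85) == 54.85
#   False
#
# Casting up front keeps the exact-equality checks in the tests meaningful.
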
def _deserialize_alert(alert_bytes):
    """Deserialize an alert message from Kafka.

    Parameters
    ----------
    alert_bytes : `bytes`
        Binary-encoded, serialized Avro alert, including the Confluent Wire
        Format prefix.

    Returns
    -------
    alert : `dict`
        An alert payload.
    """
    schema = alertPack.Schema.from_uri(str(alertPack.get_uri_to_latest_schema()))
    content_bytes = io.BytesIO(alert_bytes[5:])

    return fastavro.schemaless_reader(content_bytes, schema.definition)
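
# For reference: the 5 bytes skipped in _deserialize_alert are the Confluent
# Wire Format prefix, a magic byte (0x00) followed by a 4-byte big-endian
# schema registry ID. A sketch of unpacking the prefix explicitly (a
# hypothetical helper, not used by these tests):
#
#   import struct
#   magic, schema_id = struct.unpack(">bI", alert_bytes[:5])
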
class TestPackageAlerts(lsst.utils.tests.TestCase):

    def setUp(self):
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure)
        diaSourceHistory = utils_tests.makeDiaSources(
            10, diaObjects["diaObjectId"], self.exposure)
        diaForcedSources = utils_tests.makeDiaForcedSources(
            10, diaObjects["diaObjectId"], self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date.toAstropy())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan,
                                      inplace=True)
        diaSourceHistory["programId"] = 0

        self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :]
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
                                                              (2, "g", 10)])

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Set a minimum cutout size just below the default cutout size.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request smaller than the minimum is clamped up to the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent,
                         geom.Extent2I(packConfig.minCutoutSize,
                                       packConfig.minCutoutSize))
        # A request at or above the minimum is used as-is.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent,
                         geom.Extent2I(self.cutoutSize, self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A cutout centered far off the exposure returns None.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)
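
    # Background note (informal): the "local transform matrix" is a 2x2
    # CD-style matrix, i.e. a linear approximation of the pixel-to-sky
    # mapping evaluated at the cutout center. For the simple TAN WCS built
    # in setUp() it should reproduce the cutout's own CD matrix, hence the
    # tight rtol/atol above.
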
    def testStreamCcdDataToBytes(self):
        """Test round tripping a CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 10)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"], cutoutBytes)
            self.assertEqual(alert["cutoutScience"], cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"], cutoutBytes)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """Test that produceAlerts raises if the password is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """Test that produceAlerts raises if the username is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_USERNAME']
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """Test that produceAlerts raises if the server is empty or missing.
        """
        self.environ['AP_KAFKA_SERVER'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_SERVER']
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """Test that produceAlerts raises if the topic is empty or missing.
        """
        self.environ['AP_KAFKA_TOPIC'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_TOPIC']
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_success(self, mock_server_check, mock_producer):
        """Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2)]
        ccdVisitId = 123

        # Grab the instance of the patched Kafka producer so its calls can
        # be inspected below.
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts) + 1)
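
    # Note on the flush arithmetic above: combined with the failure test
    # below (which sees exactly len(alerts) flushes when one produce call
    # raises), these counts imply that produceAlerts flushes once per
    # successfully produced alert plus a final flush at the end.
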
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_one_failure(self, mock_server_check, mock_producer):
        """Test that produceAlerts fails correctly on one alert and writes
        the failed alert to disk.
        """
        counter = 0

        def mock_produce(*args, **kwargs):
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException
            else:
                return

        packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        patcher = patch("builtins.open")
        patch_open = patcher.start()
        alerts = [mock_alert(1), mock_alert(2), mock_alert(3)]
        ccdVisitId = 123

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()

        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn("123_2.avro", patch_open.call_args.args[0])
        # Because one produce call raises, flush is called one fewer time
        # than in the success test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))
        patcher.stop()

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce(self, mock_server_check):
        """Test the run method of package alerts with doProduceAlerts set to
        False and doWriteAlerts set to True.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
            packConfig.alertWriteLocation = tempdir
            packageAlerts = PackageAlertsTask(config=packConfig)

            packageAlerts.run(self.diaSources,
                              self.diaObjects,
                              self.diaSourceHistory,
                              self.diaForcedSources,
                              self.exposure,
                              self.exposure,
                              self.exposure)

            self.assertEqual(mock_server_check.call_count, 0)

            ccdVisitId = self.exposure.info.id
            with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
                writer_schema, data_stream = \
                    packageAlerts.alertSchema.retrieve_alerts(f)
                data = list(data_stream)

            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                for key, value in alert["diaSource"].items():
                    if isinstance(value, float):
                        if np.isnan(self.diaSources.iloc[idx][key]):
                            self.assertTrue(np.isnan(value))
                        else:
                            # Compare relative rather than exact equality;
                            # many floats are 32-bit in the avro schema
                            # (see mock_alert above), so exact comparison
                            # against the 64-bit source values would fail.
                            self.assertAlmostEqual(
                                1 - value / self.diaSources.iloc[idx][key],
                                0.)
                    else:
                        self.assertEqual(value, self.diaSources.iloc[idx][key])
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["dec"],
                                            geom.degrees)
                cutout = self.exposure.getCutout(sphPoint,
                                                 geom.Extent2I(self.cutoutSize,
                                                               self.cutoutSize))
                ccdCutout = packageAlerts.createCcdDataCutout(
                    cutout,
                    sphPoint,
                    geom.Extent2I(self.cutoutSize, self.cutoutSize),
                    cutout.getPhotoCalib(),
                    1234)
                self.assertEqual(alert["cutoutDifference"],
                                 packageAlerts.streamCcdDataToBytes(ccdCutout))

    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def testRun_with_produce(self, mock_server_check, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.
        """
        # Note: unittest.mock injects the patched objects bottom-up, so the
        # argument order is the reverse of the decorator order.
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          self.exposure)
        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        packConfig = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=packConfig)

        alert = mock_alert(1)
        serialized = packageAlerts._serializeAlert(alert)
        deserialized = _deserialize_alert(serialized)

        for field in alert['diaSource']:
            self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
        self.assertEqual(1, deserialized["alertId"])

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_server_check(self):
        """Test that constructing a task with doProduceAlerts set raises
        when the Kafka server cannot be reached.
        """
        with self.assertRaisesRegex(KafkaException, "_TRANSPORT"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)
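
    # Note: "_TRANSPORT" is librdkafka's error code for a broker transport
    # failure, which is what the connection attempt to the placeholder
    # "fake_server" host set in setUp() is expected to raise.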


class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()