Coverage for tests/test_packageAlerts.py: 30% of 274 statements (coverage.py v7.4.4, created at 2024-03-29 11:37 +0000)
# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import io
import os

import numpy as np
import pandas as pd
import shutil
import tempfile
import unittest
from unittest.mock import patch, Mock
from astropy import wcs
from astropy.nddata import CCDData
import fastavro
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
from lsst.afw.cameraGeom.testUtils import DetectorWrapper
import lsst.afw.image as afwImage
import lsst.daf.base as dafBase
from lsst.dax.apdb import Apdb, ApdbSql, ApdbSqlConfig
import lsst.geom as geom
import lsst.meas.base.tests
from lsst.sphgeom import Box
import lsst.utils.tests
import utils_tests
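# utils_tests is a local helper module that lives alongside this file; it
# provides the makeDiaObjects/makeDiaSources/makeDiaForcedSources factories
# used in setUp below to build catalogs sized to the test exposure.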


def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `astropy.time.Time`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects.
    sources : `pandas.DataFrame`
        Round tripped sources.
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources.
    """
    tmpFile = tempfile.NamedTemporaryFile()

    apdbConfig = ApdbSqlConfig()
    apdbConfig.db_url = "sqlite:///" + tmpFile.name
    apdbConfig.dia_object_index = "baseline"
    apdbConfig.dia_object_columns = []

    Apdb.makeSchema(apdbConfig)
    apdb = ApdbSql(config=apdbConfig)

    wholeSky = Box.full()
    diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
    diaSources = pd.concat(
        [apdb.getDiaSources(wholeSky, [], dateTime), sources])
    diaForcedSources = pd.concat(
        [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

    apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

    diaObjects = apdb.getDiaObjects(wholeSky)
    diaSources = apdb.getDiaSources(wholeSky,
                                    np.unique(diaObjects["diaObjectId"]),
                                    dateTime)
    diaForcedSources = apdb.getDiaForcedSources(
        wholeSky, np.unique(diaObjects["diaObjectId"]), dateTime)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
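
# Typical use (see setUp below): the catalogs built by utils_tests are passed
# through _roundTripThroughApdb so their column dtypes and index structure
# match what the real APDB would hand to the alert-packaging code.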


def mock_alert(alert_id):
    """Generate a minimal mock alert.
    """
    return {
        "alertId": alert_id,
        "diaSource": {
            "midpointMjdTai": 5,
            "diaSourceId": 4,
            "ccdVisitId": 2,
            "band": 'g',
            "ra": 12.5,
            "dec": -16.9,
            # These types are 32-bit floats in the avro schema, so we have to
            # make them that type here, so that they round trip appropriately.
            "x": np.float32(15.7),
            "y": np.float32(89.8),
            "apFlux": np.float32(54.85),
            "apFluxErr": np.float32(70.0),
            "snr": np.float32(6.7),
            "psfFlux": np.float32(700.0),
            "psfFluxErr": np.float32(90.0),
            "flags": 12345,
        }
    }


def _deserialize_alert(alert_bytes):
    """Deserialize an alert message from Kafka.

    Parameters
    ----------
    alert_bytes : `bytes`
        Binary-encoded, serialized Avro alert, including the Confluent Wire
        Format prefix.

    Returns
    -------
    alert : `dict`
        An alert payload.
    """
    schema = alertPack.Schema.from_uri(str(alertPack.get_uri_to_latest_schema()))
    # Strip the 5-byte Confluent Wire Format prefix before decoding.
    content_bytes = io.BytesIO(alert_bytes[5:])

    return fastavro.schemaless_reader(content_bytes, schema.definition)
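

# For reference, the Confluent Wire Format prefix stripped above is one magic
# byte (0x00) followed by a 4-byte big-endian schema registry ID; bytes 5+
# hold the schemaless-Avro payload. The sketch below shows the inverse
# operation under those assumptions; it is hypothetical and is NOT the task's
# own _serializeAlert, which is what the tests actually exercise.
def _serialize_alert_sketch(alert, schema, schema_id=0):
    """Encode ``alert`` with a Confluent Wire Format prefix (sketch only)."""
    import struct

    buf = io.BytesIO()
    # Magic byte 0x00, then the 4-byte big-endian schema registry ID.
    buf.write(struct.pack(">bI", 0, schema_id))
    fastavro.schemaless_writer(buf, schema.definition, alert)
    return buf.getvalue()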


class TestPackageAlerts(lsst.utils.tests.TestCase):

    def setUp(self):
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure)
        diaSourceHistory = utils_tests.makeDiaSources(
            10, diaObjects["diaObjectId"], self.exposure)
        diaForcedSources = utils_tests.makeDiaForcedSources(
            10, diaObjects["diaObjectId"], self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date.toAstropy())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan,
                                      inplace=True)
        diaSourceHistory["programId"] = 0

        # Hold two DiaSources out as the sources currently being alerted on;
        # the rest remain as history.
        self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :]
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
                                                              (2, "g", 10)])

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Set a minimum cutout size just below the test cutout size.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request smaller than the minimum is padded up to the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # A request at or above the minimum is returned unchanged.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A cutout center far off the exposure yields None rather than
        # a CCDData object.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping a CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test assembling the various data products into an alert
        dictionary.
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            # The same cutout stands in for the difference, science, and
            # template images; only byte-level equality is checked below.
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 10)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"], cutoutBytes)
            self.assertEqual(alert["cutoutScience"], cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"], cutoutBytes)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """Test that produceAlerts raises if the password is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """Test that produceAlerts raises if the username is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_USERNAME']
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """Test that produceAlerts raises if the server is empty or missing.
        """
        self.environ['AP_KAFKA_SERVER'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_SERVER']
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """Test that produceAlerts raises if the topic is empty or missing.
        """
        self.environ['AP_KAFKA_TOPIC'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_TOPIC']
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_success(self, mock_server_check, mock_producer):
        """Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2)]
        ccdVisitId = 123

        # Grab the producer instance created inside the task from the
        # patched Producer class.
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(mock_server_check.call_count, 2)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        # One flush per successfully produced alert plus a final flush.
        self.assertEqual(producer_instance.flush.call_count, len(alerts) + 1)

    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_one_failure(self, mock_server_check, mock_producer):
        """Test that produceAlerts handles a failure on one alert and
        writes that alert to disk.
        """
        counter = 0

        def mock_produce(*args, **kwargs):
            # Raise on the second alert only.
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException
            else:
                return

        packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        patcher = patch("builtins.open")
        patch_open = patcher.start()
        # Stop the patch even if an assertion below fails.
        self.addCleanup(patcher.stop)
        alerts = [mock_alert(1), mock_alert(2), mock_alert(3)]
        ccdVisitId = 123

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()

        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(mock_server_check.call_count, 2)
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn("123_2.avro", patch_open.call_args.args[0])
        # Because one produce raises, we call flush one fewer time than in
        # the success test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))

    @patch.object(PackageAlertsTask, '_server_check')
    def testRun_without_produce(self, mock_server_check):
        """Test the run method of package alerts with doProduceAlerts set to
        False and doWriteAlerts set to True.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Remove the temporary directory even if an assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          self.exposure)

        self.assertEqual(mock_server_check.call_count, 0)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            for key, value in alert["diaSource"].items():
                if isinstance(value, float):
                    # Compare floats as a relative difference so that
                    # float32 round tripping does not break exact equality.
                    if np.isnan(self.diaSources.iloc[idx][key]):
                        self.assertTrue(np.isnan(value))
                    else:
                        self.assertAlmostEqual(
                            1 - value / self.diaSources.iloc[idx][key],
                            0.)
                else:
                    self.assertEqual(value, self.diaSources.iloc[idx][key])
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @patch.object(PackageAlertsTask, '_server_check')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def testRun_with_produce(self, mock_server_check, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.

        Note that patch decorators inject their mock arguments bottom-up, so
        the parameter order here is the reverse of the decorator order.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          self.exposure)
        self.assertEqual(mock_server_check.call_count, 1)
        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        packConfig = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=packConfig)

        alert = mock_alert(1)
        serialized = packageAlerts._serializeAlert(alert)
        deserialized = _deserialize_alert(serialized)

        for field in alert['diaSource']:
            self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
        self.assertEqual(1, deserialized["alertId"])

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_server_check(self):
        """Test that constructing the task raises when the fake Kafka server
        cannot be reached.
        """
        with self.assertRaisesRegex(KafkaException, "_TRANSPORT"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()