Coverage for tests/test_packageAlerts.py: 27%
261 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 03:19 -0700
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
25import numpy as np
26import pandas as pd
27import shutil
28import tempfile
29import unittest
30from unittest.mock import patch, Mock
31from astropy import wcs
32from astropy.nddata import CCDData
33import fastavro
34try:
35 import confluent_kafka
36 from confluent_kafka import KafkaException
37except ImportError:
38 confluent_kafka = None
40import lsst.alert.packet as alertPack
41from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
42from lsst.afw.cameraGeom.testUtils import DetectorWrapper
43import lsst.afw.image as afwImage
44import lsst.daf.base as dafBase
45from lsst.dax.apdb import Apdb, ApdbSql, ApdbSqlConfig
46import lsst.geom as geom
47import lsst.meas.base.tests
48from lsst.sphgeom import Box
49import lsst.utils.tests
50import utils_tests
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects.
    sources : `pandas.DataFrame`
        Round tripped sources.
    """
    # Back the Apdb with a throwaway sqlite database file.
    dbFile = tempfile.NamedTemporaryFile()

    apdbConfig = ApdbSqlConfig()
    apdbConfig.db_url = "sqlite:///" + dbFile.name
    apdbConfig.dia_object_index = "baseline"
    apdbConfig.dia_object_columns = []

    Apdb.makeSchema(apdbConfig)
    apdb = ApdbSql(config=apdbConfig)

    # Merge the test catalogs with whatever the (empty) Apdb returns, so
    # the frames pick up every column the Apdb schema defines.
    wholeSky = Box.full()
    objs = pd.concat([apdb.getDiaObjects(wholeSky), objects])
    srcs = pd.concat(
        [apdb.getDiaSources(wholeSky, [], dateTime), sources])
    forced = pd.concat(
        [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

    apdb.store(dateTime, objs, srcs, forced)

    # Read everything back out so the returned frames are exactly what the
    # Apdb would serve in production.
    objs = apdb.getDiaObjects(wholeSky)
    objectIds = np.unique(objs["diaObjectId"])
    srcs = apdb.getDiaSources(wholeSky, objectIds, dateTime)
    forced = apdb.getDiaForcedSources(wholeSky, objectIds, dateTime)

    objs.set_index("diaObjectId", drop=False, inplace=True)
    srcs.set_index(["diaObjectId", "band", "diaSourceId"],
                   drop=False,
                   inplace=True)
    forced.set_index(["diaObjectId"], drop=False, inplace=True)

    return (objs, srcs, forced)
def mock_alert(alert_id):
    """Generate a minimal mock alert.

    Parameters
    ----------
    alert_id : `int`
        Identifier to stamp on the returned alert.

    Returns
    -------
    alert : `dict`
        Minimal alert payload containing ``alertId`` and a ``diaSource``
        record.
    """
    # Fields that are 32-bit floats in the avro schema must be created as
    # np.float32 here so that they round trip exactly through serialization.
    diaSource = {
        "midpointMjdTai": 5,
        "diaSourceId": 4,
        "ccdVisitId": 2,
        "band": 'g',
        "ra": 12.5,
        "dec": -16.9,
        "x": np.float32(15.7),
        "y": np.float32(89.8),
        "apFlux": np.float32(54.85),
        "apFluxErr": np.float32(70.0),
        "snr": np.float32(6.7),
        "psfFlux": np.float32(700.0),
        "psfFluxErr": np.float32(90.0),
        "flags": 12345,
    }
    return {"alertId": alert_id, "diaSource": diaSource}
def _deserialize_alert(alert_bytes):
    """Deserialize an alert message from Kafka.

    Parameters
    ----------
    alert_bytes : `bytes`
        Binary-encoding serialized Avro alert, including Confluent Wire
        Format prefix.

    Returns
    -------
    alert : `dict`
        An alert payload.
    """
    latestSchema = alertPack.Schema.from_uri(
        str(alertPack.get_uri_to_latest_schema()))
    # Skip the 5-byte Confluent Wire Format header before decoding.
    payload = io.BytesIO(alert_bytes[5:])
    return fastavro.schemaless_reader(payload, latestSchema.definition)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for PackageAlertsTask: cutout creation, alert dictionary
    packaging, Kafka production, and on-disk alert writing.
    """

    def setUp(self):
        # Provide fake Kafka credentials; individual tests blank or delete
        # these to exercise the config-validation error paths.
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        # Build test catalogs and round trip them through the Apdb so they
        # carry the production table schemas.
        diaObjects = utils_tests.makeDiaObjects(2, self.exposure)
        diaSourceHistory = utils_tests.makeDiaSources(10,
                                                      diaObjects[
                                                          "diaObjectId"],
                                                      self.exposure)
        diaForcedSources = utils_tests.makeDiaForcedSources(10,
                                                            diaObjects[
                                                                "diaObjectId"],
                                                            self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date)
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan,
                                      inplace=True)
        diaSourceHistory["programId"] = 0

        # Split the newest source per object off from its history.
        self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :]
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
                                                              (2, "g", 10)])

        # An astropy WCS equivalent to the exposure's, used to validate
        # the WCS attached to cutouts.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # Requests below the minimum are clamped up to minCutoutSize.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertTrue(extent == geom.Extent2I(packConfig.minCutoutSize,
                                                packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertTrue(extent == geom.Extent2I(self.cutoutSize,
                                                self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A center far off the exposure must yield no cutout.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertTrue(ccdData is None)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # srcIdx[0] is the diaObjectId part of the multi-index.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 10)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutScience"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """ Test that produceAlerts raises if the password is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """ Test that produceAlerts raises if the username is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_USERNAME']
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """ Test that produceAlerts raises if the server is empty or missing.
        """
        self.environ['AP_KAFKA_SERVER'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_SERVER']
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """ Test that produceAlerts raises if the topic is empty or missing.
        """
        self.environ['AP_KAFKA_TOPIC'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_TOPIC']
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @patch('confluent_kafka.Producer')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_success(self, mock_producer):
        """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2)]
        ccdVisitId = 123

        # Create a variable and assign it an instance of the patched kafka producer
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

    @patch('confluent_kafka.Producer')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_one_failure(self, mock_producer):
        """ Test that produceAlerts correctly fails on one alert
        and is writing the failure to disk.
        """
        counter = 0

        # Raise on exactly the second produce() call.
        def mock_produce(*args, **kwargs):
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException
            else:
                return

        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        # Use addCleanup so the open() patch is removed even if an
        # assertion below fails.
        patcher = patch("builtins.open")
        patch_open = patcher.start()
        self.addCleanup(patcher.stop)
        alerts = [mock_alert(1), mock_alert(2), mock_alert(3)]
        ccdVisitId = 123

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()

        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn("123_2.avro", patch_open.call_args.args[0])
        # Because one produce raises, we call flush one fewer times than in the success
        # test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))

    def testRun_without_produce(self):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to true.
        """
        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        tempdir = tempfile.mkdtemp(prefix='alerts')
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          self.exposure)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            for key, value in alert["diaSource"].items():
                if isinstance(value, float):
                    if np.isnan(self.diaSources.iloc[idx][key]):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Compare as a relative difference to tolerate the
                        # float32 round trip through avro.
                        self.assertAlmostEqual(
                            1 - value / self.diaSources.iloc[idx][key],
                            0.)
                else:
                    self.assertEqual(value, self.diaSources.iloc[idx][key])
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

        shutil.rmtree(tempdir)

    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def testRun_with_produce(self, mock_producer, mock_produceAlerts):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.
        """
        # Stacked @patch decorators inject mocks bottom-up: the innermost
        # patch (confluent_kafka.Producer) is the FIRST mock argument and the
        # outermost (produceAlerts) is the second.  The previous signature
        # had these two names swapped.
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        # NOTE(review): only two exposure arguments are passed here, vs
        # three in testRun_without_produce — confirm run()'s last exposure
        # argument is optional.
        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure)

        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        ConfigClass = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=ConfigClass)

        alert = mock_alert(1)
        serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert)
        deserialized = _deserialize_alert(serialized)

        for field in alert['diaSource']:
            self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
        self.assertEqual(1, deserialized["alertId"])
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook this module into the standard lsst.utils.tests memory checks."""
def setup_module(module):
    """Initialize the LSST test utilities before this module's tests run."""
    lsst.utils.tests.init()
# Allow running this test module directly (repairs a line that had
# coverage-report annotations fused into the source).
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()