Coverage for tests/test_packageAlerts.py: 27%
260 statements
« prev ^ index » next    coverage.py v7.4.4, created at 2024-03-19 09:44 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
25import numpy as np
26import pandas as pd
27import shutil
28import tempfile
29import unittest
30from unittest.mock import patch, Mock
31from astropy import wcs
32from astropy.nddata import CCDData
33import fastavro
34try:
35 import confluent_kafka
36 from confluent_kafka import KafkaException
37except ImportError:
38 confluent_kafka = None
40import lsst.alert.packet as alertPack
41from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
42from lsst.afw.cameraGeom.testUtils import DetectorWrapper
43import lsst.afw.image as afwImage
44import lsst.daf.base as dafBase
45from lsst.dax.apdb import Apdb, ApdbSql, ApdbSqlConfig
46import lsst.geom as geom
47import lsst.meas.base.tests
48from lsst.sphgeom import Box
49import lsst.utils.tests
50import utils_tests
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects.
    sources : `pandas.DataFrame`
        Round tripped sources.
    """
    # Back the Apdb with a throw-away sqlite file.  NOTE(review): the
    # NamedTemporaryFile is only referenced by this local, so the file is
    # deleted when the function returns; the returned DataFrames are already
    # materialized by then, so this is safe as written.
    tmpFile = tempfile.NamedTemporaryFile()

    apdbConfig = ApdbSqlConfig()
    apdbConfig.db_url = "sqlite:///" + tmpFile.name
    apdbConfig.dia_object_index = "baseline"
    apdbConfig.dia_object_columns = []

    # Create the tables before instantiating the Apdb client.
    Apdb.makeSchema(apdbConfig)
    apdb = ApdbSql(config=apdbConfig)

    # Query over the whole sky so every stored row is retrieved.
    wholeSky = Box.full()
    # Concatenating with the (empty) query results coerces the test catalogs
    # to the Apdb table schemas (column set and dtypes).
    diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
    diaSources = pd.concat(
        [apdb.getDiaSources(wholeSky, [], dateTime), sources])
    diaForcedSources = pd.concat(
        [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

    # Store then re-read so the returned catalogs are exactly what the Apdb
    # round trip produces.
    apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

    diaObjects = apdb.getDiaObjects(wholeSky)
    diaSources = apdb.getDiaSources(wholeSky,
                                    np.unique(diaObjects["diaObjectId"]),
                                    dateTime)
    diaForcedSources = apdb.getDiaForcedSources(
        wholeSky, np.unique(diaObjects["diaObjectId"]), dateTime)

    # Index the catalogs the way the tests expect to address rows;
    # drop=False keeps the index columns available as regular columns too.
    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
def mock_alert(alert_id):
    """Build a minimal mock alert payload.

    Parameters
    ----------
    alert_id : `int`
        Value stored in the alert's ``alertId`` field.

    Returns
    -------
    alert : `dict`
        Minimal alert containing a single ``diaSource`` record.
    """
    source = {
        "midpointMjdTai": 5,
        "diaSourceId": 4,
        "ccdVisitId": 2,
        "band": 'g',
        "ra": 12.5,
        "dec": -16.9,
    }
    # These fields are 32-bit floats in the avro schema, so build them as
    # numpy float32 here so that they round trip appropriately.
    float32_fields = (("x", 15.7),
                      ("y", 89.8),
                      ("apFlux", 54.85),
                      ("apFluxErr", 70.0),
                      ("snr", 6.7),
                      ("psfFlux", 700.0),
                      ("psfFluxErr", 90.0))
    for name, value in float32_fields:
        source[name] = np.float32(value)
    source["flags"] = 12345
    return {"alertId": alert_id, "diaSource": source}
def _deserialize_alert(alert_bytes):
    """Deserialize an alert message from Kafka.

    Parameters
    ----------
    alert_bytes : `bytes`
        Binary-encoding serialized Avro alert, including Confluent Wire
        Format prefix.

    Returns
    -------
    alert : `dict`
        An alert payload.
    """
    latest_schema = alertPack.Schema.from_uri(
        str(alertPack.get_uri_to_latest_schema()))
    # Skip the 5-byte Confluent Wire Format prefix; the avro payload
    # starts immediately after it.
    payload = io.BytesIO(alert_bytes[5:])
    return fastavro.schemaless_reader(payload, latest_schema.definition)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Unit tests for PackageAlertsTask: cutout creation, alert packaging,
    Kafka configuration validation, alert production, and avro round trips.
    """

    def setUp(self):
        # Supply fake Kafka credentials so PackageAlertsTask construction
        # succeeds; individual tests below clear or delete entries to
        # exercise the task's configuration validation.
        patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
                                          "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
                                          "AP_KAFKA_SERVER": "fake_server",
                                          "AP_KAFKA_TOPIC": "fake_topic"})
        self.environ = patcher.start()
        self.addCleanup(patcher.stop)
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        # Synthetic exposure with one source at self.center.
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(
            afwImage.FilterLabel(band='g', physical="g.MP9401"))

        # Build test catalogs and push them through the Apdb so they carry
        # the production table schemas.
        diaObjects = utils_tests.makeDiaObjects(2, self.exposure)
        diaSourceHistory = utils_tests.makeDiaSources(10,
                                                      diaObjects[
                                                          "diaObjectId"],
                                                      self.exposure)
        diaForcedSources = utils_tests.makeDiaForcedSources(10,
                                                            diaObjects[
                                                                "diaObjectId"],
                                                            self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date)
        # Normalize missing values to NaN for the float comparisons below.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan,
                                      inplace=True)
        diaSourceHistory["programId"] = 0

        # Split the history into two "current" sources (one per object)
        # and the remaining history rows.
        self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :]
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
                                                              (2, "g", 10)])

        # Reference astropy WCS matching the exposure WCS, used to check
        # the WCS written into CCDData cutouts.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # Requests smaller than minCutoutSize are clamped up to the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertTrue(extent == geom.Extent2I(packConfig.minCutoutSize,
                                                packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertTrue(extent == geom.Extent2I(self.cutoutSize,
                                                self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        # Cutout covering the full exposure at its sky origin should
        # reproduce the calibrated image and WCS exactly.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A center far off the exposure yields no cutout (None).
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertTrue(ccdData is None)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        # The local linear approximation should match the cutout CD matrix
        # to near machine precision at the cutout center.
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        # Serialize to FITS bytes, then read back and compare pixel data.
        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # srcIdx[0] is the diaObjectId component of the MultiIndex.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            # The same cutout is passed as both difference and template.
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_password(self):
        """ Test that produceAlerts raises if the password is empty or missing.
        """
        # Empty value and missing variable must both be rejected.
        self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
        with self.assertRaisesRegex(ValueError, "Kafka password"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_username(self):
        """ Test that produceAlerts raises if the username is empty or missing.
        """
        self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_PRODUCER_USERNAME']
        with self.assertRaisesRegex(ValueError, "Kafka username"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_server(self):
        """ Test that produceAlerts raises if the server is empty or missing.
        """
        self.environ['AP_KAFKA_SERVER'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_SERVER']
        with self.assertRaisesRegex(ValueError, "Kafka server"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_empty_topic(self):
        """ Test that produceAlerts raises if the topic is empty or missing.
        """
        self.environ['AP_KAFKA_TOPIC'] = ""
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

        del self.environ['AP_KAFKA_TOPIC']
        with self.assertRaisesRegex(ValueError, "Kafka topic"):
            packConfig = PackageAlertsConfig(doProduceAlerts=True)
            PackageAlertsTask(config=packConfig)

    @patch('confluent_kafka.Producer')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_success(self, mock_producer):
        """ Test that produceAlerts calls the producer on all provided alerts
        when the alerts are all under the batch size limit.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)
        alerts = [mock_alert(1), mock_alert(2)]
        ccdVisitId = 123

        # Create a variable and assign it an instance of the patched kafka producer
        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock()
        producer_instance.flush = Mock()
        packageAlerts.produceAlerts(alerts, ccdVisitId)

        # One produce per alert; flush is apparently called once per alert
        # plus a final flush — TODO confirm against produceAlerts.
        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

    @patch('confluent_kafka.Producer')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def test_produceAlerts_one_failure(self, mock_producer):
        """ Test that produceAlerts correctly fails on one alert
        and is writing the failure to disk.
        """
        counter = 0

        # Raise only on the second produce call (alert id 2).
        def mock_produce(*args, **kwargs):
            nonlocal counter
            counter += 1
            if counter == 2:
                raise KafkaException
            else:
                return

        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        # Patch open() so the failed alert is "written" without touching disk.
        patcher = patch("builtins.open")
        patch_open = patcher.start()
        alerts = [mock_alert(1), mock_alert(2), mock_alert(3)]
        ccdVisitId = 123

        producer_instance = mock_producer.return_value
        producer_instance.produce = Mock(side_effect=mock_produce)
        producer_instance.flush = Mock()

        packageAlerts.produceAlerts(alerts, ccdVisitId)

        self.assertEqual(producer_instance.produce.call_count, len(alerts))
        # Exactly the failed alert is written, named <ccdVisitId>_<alertId>.avro.
        self.assertEqual(patch_open.call_count, 1)
        self.assertIn("123_2.avro", patch_open.call_args.args[0])
        # Because one produce raises, we call flush one fewer times than in the success
        # test above.
        self.assertEqual(producer_instance.flush.call_count, len(alerts))
        patcher.stop()

    def testRun_without_produce(self):
        """Test the run method of package alerts with produce set to False and
        doWriteAlerts set to true.
        """

        packConfig = PackageAlertsConfig(doWriteAlerts=True)
        tempdir = tempfile.mkdtemp(prefix='alerts')
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure)

        # run() writes one avro file per ccdVisitId; read it back and verify
        # every diaSource field round-tripped.
        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            for key, value in alert["diaSource"].items():
                if isinstance(value, float):
                    # Compare floats by relative error so avro float32
                    # round trips compare equal; NaN compares to NaN.
                    if np.isnan(self.diaSources.iloc[idx][key]):
                        self.assertTrue(np.isnan(value))
                    else:
                        self.assertAlmostEqual(
                            1 - value / self.diaSources.iloc[idx][key],
                            0.)
                else:
                    self.assertEqual(value, self.diaSources.iloc[idx][key])
            # Rebuild the expected cutout and compare its serialized bytes.
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))

        # NOTE(review): not run on test failure; consider self.addCleanup.
        shutil.rmtree(tempdir)

    @patch.object(PackageAlertsTask, 'produceAlerts')
    @patch('confluent_kafka.Producer')
    @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
    def testRun_with_produce(self, mock_produceAlerts, mock_producer):
        """Test that packageAlerts calls produceAlerts when doProduceAlerts
        is set to True.

        NOTE(review): @patch decorators apply bottom-up, so the inner
        'confluent_kafka.Producer' patch supplies the FIRST mock argument —
        the parameter names here appear swapped relative to the patches.
        The assertion may be checking the Producer mock instead of
        produceAlerts; verify and rename.
        """
        packConfig = PackageAlertsConfig(doProduceAlerts=True)
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure)

        self.assertEqual(mock_produceAlerts.call_count, 1)

    def test_serialize_alert_round_trip(self):
        """Test that values in the alert packet exactly round trip.
        """
        ConfigClass = PackageAlertsConfig()
        packageAlerts = PackageAlertsTask(config=ConfigClass)

        alert = mock_alert(1)
        serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert)
        deserialized = _deserialize_alert(serialized)

        # Every diaSource field (including float32 values) must survive
        # serialization unchanged.
        for field in alert['diaSource']:
            self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
        self.assertEqual(1, deserialized["alertId"])
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Standard LSST test-suite hook that checks for resource leaks."""
    pass
def setup_module(module):
    # pytest module-level setup hook: initialize the LSST test framework.
    lsst.utils.tests.init()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()