Coverage for python/lsst/ap/association/packageAlerts.py: 22% (184 statements)
# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("PackageAlertsConfig", "PackageAlertsTask")

import io
import os
import struct
import sys

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import fastavro
import pandas as pd
# confluent_kafka is not in the standard Rubin environment as it is a third
# party package and is only needed when producing alerts.
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
    from confluent_kafka.admin import AdminClient
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
import lsst.geom as geom
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )

    doWriteFailedAlerts = pexConfig.Field(
        dtype=bool,
        doc="If an alert cannot be sent when doProduceAlerts is set, "
            "write it to disk for debugging purposes.",
        default=False,
    )

    maxTimeout = pexConfig.Field(
        dtype=float,
        doc="Maximum time in seconds to wait for the alert stream "
            "broker to respond to a query before timing out.",
        default=15.0,
    )

    deliveryTimeout = pexConfig.Field(
        dtype=float,
        doc="Time in milliseconds to wait for the producer to deliver "
            "an alert before reporting a failure.",
        default=1200.0,
    )
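
    # A minimal sketch of overriding these fields from a pipeline config
    # file. The ``alertPackager`` attribute name is hypothetical and depends
    # on how the parent task registers this subtask:
    #
    #     config.alertPackager.doWriteAlerts = True
    #     config.alertPackager.minCutoutSize = 50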


class PackageAlertsTask(pipeBase.Task):
    """Task for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")
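
                # Illustrative shell setup for the four environment variables
                # read above (values are placeholders):
                #
                #     export AP_KAFKA_PRODUCER_USERNAME=kafka-user
                #     export AP_KAFKA_PRODUCER_PASSWORD=...
                #     export AP_KAFKA_SERVER=broker.example.org:9092
                #     export AP_KAFKA_TOPIC=alerts-topic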

                # confluent_kafka configures all of its classes with
                # dictionaries. This one sets up the bare minimum that is
                # needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert
                    # that can be sent. We set the batch size to 2 MB.
                    "batch.size": 2097152,
                    "linger.ms": 5,
                    "delivery.timeout.ms": self.config.deliveryTimeout,
                }
                self.kafkaAdminConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                }

                self._server_check()
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("doProduceAlerts is set but confluent_kafka "
                                   "is not present in the environment. Alerts "
                                   "will not be sent to the alert stream.")

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            calexp,
            template,
            doRunForcedMeasurement=True,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at
        the same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources.
            DataFrame is indexed on ``["diaObjectId"]``.
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        calexp : `lsst.afw.image.ExposureF`
            Calexp used to create the ``diffIm``.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``.
        doRunForcedMeasurement : `bool`, optional
            Flag to indicate whether forced measurement was run.
            This should only be turned off for debugging purposes.
            Added to allow disabling forced sources for performance
            reasons during the ops rehearsal.
        """
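        # Illustrative index layout (hypothetical values): each row of
        # ``diaSourceCat`` is keyed by a (diaObjectId, band, diaSourceId)
        # tuple, e.g. (1234, "g", 5678), so ``srcIndex[0]`` below is the
        # diaObjectId used to look up matching history and forced sources.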
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        detector = diffIm.detector.getId()
        visit = diffIm.visitInfo.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        calexpPhotoCalib = calexp.getPhotoCalib()
        templatePhotoCalib = template.getPhotoCalib() if template is not None else None
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            if doRunForcedMeasurement:
                objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            else:
                # Send empty table with correct columns
                objDiaForcedSources = diaForcedSources.loc[[]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            calexpCutout = self.createCcdDataCutout(
                calexp,
                sphPoint,
                cutoutExtent,
                calexpPhotoCalib,
                diaSource["diaSourceId"])
            # The template may be absent, in which case no template cutout
            # can be packaged.
            if template is not None:
                templateCutout = self.createCcdDataCutout(
                    template,
                    sphPoint,
                    cutoutExtent,
                    templatePhotoCalib,
                    diaSource["diaSourceId"])
            else:
                templateCutout = None

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   calexpCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, visit, detector)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation, f"{visit}_{detector}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)
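
        # Sketch of reading a written alert file back, assuming
        # ``store_alerts`` produces a standard Avro container file:
        #
        #     with open(".../{visit}_{detector}.avro", "rb") as f:
        #         alerts = list(fastavro.reader(f))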

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create an extent for a box for the cutouts given the size of the
        square BBox that covers the source footprint.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
        return extent
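
    # A minimal sketch of the clamping above, assuming the default
    # ``minCutoutSize`` of 30: a 10 pixel footprint still yields a 30x30
    # cutout, while a 64 pixel footprint is passed through unchanged.
    #
    #     task.createDiaSourceExtent(10)   # -> Extent2I(30, 30)
    #     task.createDiaSourceExtent(64)   # -> Extent2I(64, 64)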

    def produceAlerts(self, alerts, visit, detector):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        visit, detector : `int`
            Visit and detector ids of these alerts. Used to write out alerts
            which fail to be sent to the alert stream.
        """
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                self.producer.flush()

            except KafkaException as e:
                self.log.warning('Kafka error: %s, message was %s bytes', e, sys.getsizeof(alertBytes))

                if self.config.doWriteFailedAlerts:
                    with open(os.path.join(self.config.alertWriteLocation,
                                           f"{visit}_{detector}_{alert['alertId']}.avro"), "wb") as f:
                        f.write(alertBytes)

        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            PhotoCalib object of the image the cutout is cut from.
        srcId : `int`
            Unique id of DiaSource. Used for when an error occurs extracting
            a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrated information from the input
            difference or template image.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(imBBox).contains(point):
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # array_shape is in numpy (row, column) order, i.e. (height, width).
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        compared x,y coordinates are relative to the position of center. The
        matrix is initially calculated with units of arcseconds and then
        converted to degrees. This yields higher precision results due to
        quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate.
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation of the local wcs approximation with units
            of degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
        return localMatrix / 3600
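
    # Sketch of the unit handling above: for a WCS with pixel scale ``s``
    # arcsec/pixel and no rotation, the Jacobian is approximately
    # [[s, 0], [0, s]] in arcseconds per pixel, so dividing by 3600 yields
    # the conventional CD matrix in degrees per pixel.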

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      calexpCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique id of the alert. Currently the id of the DiaSource
            (see DM-24858).
        diaSource : `pandas.DataFrame`
            New single DiaSource to package.
        diaObject : `pandas.DataFrame`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.
        calexpCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the calexp around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.

        Returns
        -------
        alert : `dict`
            Alert data packaged into a dictionary.
        """
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = None
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if calexpCutout is None:
            alert['cutoutScience'] = None
        else:
            alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert
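
    # The resulting dict mirrors the top-level fields of the alert schema:
    #
    #     {"alertId", "diaSource", "prvDiaSources", "prvDiaForcedSources",
    #      "prvDiaNondetectionLimits", "diaObject", "ssObject",
    #      "cutoutDifference", "cutoutScience", "cutoutTemplate"}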

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data.
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes
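
    # Round-trip sketch, assuming astropy's ``CCDData.read`` accepts a
    # file-like object for the FITS reader:
    #
    #     restored = CCDData.read(io.BytesIO(cutoutBytes), format="fits")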

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode ``alert``. By
            default the schema is None, which falls back to the schema
            configured for this task.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
        """
        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Return the byte prefix for Confluent Wire Format-style Kafka
        messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
        """
        ConfluentWireFormatHeader = struct.Struct(">bi")
        return ConfluentWireFormatHeader.pack(0, schema_version)
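
    # Worked example of the prefix layout: for ``schema_version=1`` the
    # header is the magic byte 0 followed by the ID as a big-endian int32,
    # i.e. b"\x00\x00\x00\x00\x01", and ``struct.unpack(">bi", header)``
    # recovers (0, 1).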

    def _delivery_callback(self, err, msg):
        if err:
            self.log.warning('Message failed delivery: %s', err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())

    def _server_check(self):
        """Check that the alert stream credentials are still valid and the
        server is contactable.

        Raises
        ------
        KafkaException
            Raised if the server is not contactable.
        RuntimeError
            Raised if the server is contactable but there are no topics
            present.
        """
        admin_client = AdminClient(self.kafkaAdminConfig)
        topics = admin_client.list_topics(timeout=self.config.maxTimeout).topics

        if not topics:
            raise RuntimeError("Server is contactable but no Kafka topics "
                               "are present.")