Coverage for python/lsst/ap/association/packageAlerts.py: 22% of 183 statements
coverage.py v7.4.4, created at 2024-04-05 03:52 -0700
# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
__all__ = ("PackageAlertsConfig", "PackageAlertsTask")

import io
import os
import sys

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import pandas as pd
import struct
import fastavro
# confluent_kafka is not in the standard Rubin environment as it is a third
# party package and is only needed when producing alerts.
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
    from confluent_kafka.admin import AdminClient
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
import lsst.geom as geom
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
54 """
55 schemaFile = pexConfig.Field(
56 dtype=str,
57 doc="Schema definition file URI for the avro alerts.",
58 default=str(alertPack.get_uri_to_latest_schema())
59 )
60 minCutoutSize = pexConfig.RangeField(
61 dtype=int,
62 min=0,
63 max=1000,
64 default=30,
65 doc="Dimension of the square image cutouts to package in the alert."
66 )
67 alertWriteLocation = pexConfig.Field(
68 dtype=str,
69 doc="Location to write alerts to.",
70 default=os.path.join(os.getcwd(), "alerts"),
71 )
73 doProduceAlerts = pexConfig.Field(
74 dtype=bool,
75 doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
76 default=False,
77 )
79 doWriteAlerts = pexConfig.Field(
80 dtype=bool,
81 doc="Write alerts to disk if true.",
82 default=False,
83 )
85 doWriteFailedAlerts = pexConfig.Field(
86 dtype=bool,
87 doc="If an alert cannot be sent when doProduceAlerts is set, "
88 "write it to disk for debugging purposes.",
89 default=False,
90 )
92 maxTimeout = pexConfig.Field(
93 dtype=float,
94 doc="Sets the maximum time in seconds to wait for the alert stream "
95 "broker to respond to a query before timing out.",
96 default=15.0,
97 )


class PackageAlertsTask(pipeBase.Task):
    """Task for packaging DiaSource/DiaObject and Pipelines data into Avro
    alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")

                # confluent_kafka configures all of its classes with dictionaries. This one
                # sets up the bare minimum that is needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about the specific
                    # authentication and authorization protocols that should be used when
                    # connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through over
                    # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                    # sensitive, but the password is (of course) a secret value which
                    # should never be committed to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert that can be sent.
                    # We set the batch size to 2 MiB (2097152 bytes).
                    "batch.size": 2097152,
                    "linger.ms": 5,
                }
                self.kafkaAdminConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about the specific
                    # authentication and authorization protocols that should be used when
                    # connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through over
                    # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                    # sensitive, but the password is (of course) a secret value which
                    # should never be committed to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                }

                self._server_check()
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("Produce alerts is set but confluent_kafka is not present in "
                                   "the environment. Alerts will not be sent to the alert stream.")

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            calexp,
            template,
            doRunForcedMeasurement=True,
            ):
        """Package DiaSources/DiaObjects and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at
        the same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources.
            DataFrame is indexed on ``["diaObjectId"]``.
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        calexp : `lsst.afw.image.ExposureF`
            Calexp used to create the ``diffIm``.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``.
        doRunForcedMeasurement : `bool`, optional
            Flag to indicate whether forced measurement was run.
            This should only be turned off for debugging purposes.
            Added to allow disabling forced sources for performance
            reasons during the ops rehearsal.
        """
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        calexpPhotoCalib = calexp.getPhotoCalib()
        # ``template`` may be None; guard the calibration lookup accordingly.
        templatePhotoCalib = template.getPhotoCalib() if template is not None else None
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            if doRunForcedMeasurement:
                objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            else:
                # Send empty table with correct columns
                objDiaForcedSources = diaForcedSources.loc[[]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            calexpCutout = self.createCcdDataCutout(
                calexp,
                sphPoint,
                cutoutExtent,
                calexpPhotoCalib,
                diaSource["diaSourceId"])
            if template is not None:
                templateCutout = self.createCcdDataCutout(
                    template,
                    sphPoint,
                    cutoutExtent,
                    templatePhotoCalib,
                    diaSource["diaSourceId"])
            else:
                templateCutout = None

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   calexpCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
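
        Examples
        --------
        A minimal sketch with a toy DataFrame; ``task`` is assumed to be an
        already-constructed instance of this class:

        >>> import pandas as pd
        >>> df = pd.DataFrame({"diaSourceId": [1, 2]})
        >>> task._patchDiaSources(df)
        >>> df["programId"].tolist()
        [0, 0]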
292 """
293 diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create an extent for a box for the cutouts given the size of the
        square BBox that covers the source footprint.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
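
        Examples
        --------
        Footprints smaller than ``minCutoutSize`` (default 30) are padded up
        to that floor. A sketch; ``task`` is assumed to be an
        already-constructed instance:

        >>> extent = task.createDiaSourceExtent(10)
        >>> (extent.getX(), extent.getY())
        (30, 30)
        >>> extent = task.createDiaSourceExtent(100)
        >>> (extent.getX(), extent.getY())
        (100, 100)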
308 """
309 if bboxSize < self.config.minCutoutSize:
310 extent = geom.Extent2I(self.config.minCutoutSize,
311 self.config.minCutoutSize)
312 else:
313 extent = geom.Extent2I(bboxSize, bboxSize)
314 return extent

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        self._server_check()
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                self.producer.flush()

            except KafkaException as e:
                self.log.warning('Kafka error: %s, message was %s bytes', e, sys.getsizeof(alertBytes))

                if self.config.doWriteFailedAlerts:
                    with open(os.path.join(self.config.alertWriteLocation,
                                           f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                        f.write(alertBytes)

        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Calibration object of the image the cutout is cut from.
        srcId : `int`
            Unique id of DiaSource. Used when an error occurs extracting
            a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrated information from the input
            difference or template image.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(imBBox).contains(point):
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.minY, imBBox.maxX, imBBox.maxY)
            else:
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # array_shape follows the numpy convention: (rows, columns), i.e.
        # (height, width).
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        compared x,y coordinates are relative to the position of ``center``.
        The matrix is initially calculated in units of arcseconds and then
        converted to degrees. This yields higher precision results due to
        quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate.
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation of the local WCS approximation, in units
            of degrees.
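
        Examples
        --------
        A sketch with a simple TAN WCS; the inputs here are illustrative and
        ``task`` is assumed to be an already-constructed instance:

        >>> import lsst.afw.geom as afwGeom
        >>> import lsst.geom as geom
        >>> center = geom.Point2D(50, 50)
        >>> skyCenter = geom.SpherePoint(45.0, 45.0, geom.degrees)
        >>> cdMatrix = afwGeom.makeCdMatrix(scale=0.2 * geom.arcseconds)
        >>> testWcs = afwGeom.makeSkyWcs(center, skyCenter, cdMatrix)
        >>> task.makeLocalTransformMatrix(testWcs, center, skyCenter).shape
        (2, 2)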
439 """
440 blankCDMatrix = [[self._scale, 0], [0, self._scale]]
441 localGnomonicWcs = afwGeom.makeSkyWcs(
442 center, skyCenter, blankCDMatrix)
443 measurementToLocalGnomonic = wcs.getTransform().then(
444 localGnomonicWcs.getTransform().inverted()
445 )
446 localMatrix = measurementToLocalGnomonic.getJacobian(center)
447 return localMatrix / 3600

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      calexpCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique id of the alert; currently the DiaSource id (see DM-24858).
        diaSource : `pandas.Series`
            New single DiaSource to package.
        diaObject : `pandas.Series`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame` or `None`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        calexpCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the calexp around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.

        Returns
        -------
        alert : `dict`
            Alert payload in the layout expected by the Avro alert schema.
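
        Examples
        --------
        A minimal sketch with toy inputs; real inputs are built in ``run``,
        and a fully populated Series is needed to satisfy the alert schema.
        ``task`` is assumed to be an already-constructed instance:

        >>> import pandas as pd
        >>> src = pd.Series({"diaSourceId": 1234})
        >>> obj = pd.Series({"diaObjectId": 42})
        >>> alert = task.makeAlertDict(
        ...     1234, src, obj, None, pd.DataFrame(), None, None, None)
        >>> alert["alertId"], alert["ssObject"]
        (1234, None)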
479 """
480 alert = dict()
481 alert['alertId'] = alertId
482 alert['diaSource'] = diaSource.to_dict()
484 if objDiaSrcHistory is None:
485 alert['prvDiaSources'] = objDiaSrcHistory
486 else:
487 alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")
489 if isinstance(objDiaForcedSources, pd.Series):
490 alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
491 else:
492 alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
493 alert['prvDiaNondetectionLimits'] = None
495 alert['diaObject'] = diaObject.to_dict()
497 alert['ssObject'] = None
499 if diffImCutout is None:
500 alert['cutoutDifference'] = None
501 else:
502 alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)
504 if calexpCutout is None:
505 alert['cutoutScience'] = None
506 else:
507 alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)
509 if templateCutout is None:
510 alert["cutoutTemplate"] = None
511 else:
512 alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)
514 return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data.
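
        Examples
        --------
        The result is a FITS file in memory. A sketch with a toy cutout;
        ``task`` is assumed to be an already-constructed instance:

        >>> import numpy as np
        >>> from astropy.nddata import CCDData
        >>> toyCutout = CCDData(np.zeros((5, 5)), unit="adu")
        >>> task.streamCcdDataToBytes(toyCutout)[:6]
        b'SIMPLE'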
528 """
529 with io.BytesIO() as streamer:
530 cutout.write(streamer, format="fits")
531 cutoutBytes = streamer.getvalue()
532 return cutoutBytes

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode ``alert``. By
            default the schema is None, which falls back to the schema
            configured for this task.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
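
        Examples
        --------
        A sketch; ``alert`` must be a payload that satisfies the schema (see
        ``makeAlertDict``), and ``schema_id=1`` mirrors the value used in
        ``produceAlerts``. The first five bytes are the wire-format header:

        >>> payload = task._serializeAlert(alert, schema_id=1)
        >>> payload[0], int.from_bytes(payload[1:5], "big")
        (0, 1)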
553 """
554 if schema is None:
555 schema = self.alertSchema.definition
557 buf = io.BytesIO()
558 # TODO: Use a proper schema versioning system (DM-42606)
559 buf.write(self._serializeConfluentWireHeader(schema_id))
560 fastavro.schemaless_writer(buf, schema, alert)
561 return buf.getvalue()

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Return the byte prefix for Confluent Wire Format-style Kafka
        messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
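
        Examples
        --------
        The header is one magic byte (0) followed by the schema ID as a
        big-endian 32-bit integer:

        >>> PackageAlertsTask._serializeConfluentWireHeader(42)
        b'\x00\x00\x00\x00*'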
583 """
584 ConfluentWireFormatHeader = struct.Struct(">bi")
585 return ConfluentWireFormatHeader.pack(0, schema_version)

    def _delivery_callback(self, err, msg):
        if err:
            self.log.warning('Message failed delivery: %s', err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())

    def _server_check(self):
        """Check that the alert stream credentials are still valid and the
        server is contactable.

        Raises
        ------
        KafkaException
            Raised if the server is not contactable.
        RuntimeError
            Raised if the server is contactable but there are no topics
            present.
        """
        admin_client = AdminClient(self.kafkaAdminConfig)
        topics = admin_client.list_topics(timeout=self.config.maxTimeout).topics

        if not topics:
            raise RuntimeError("No Kafka topics found on the server; alerts cannot be produced.")