Coverage for python/lsst/ap/association/packageAlerts.py: 21%
182 statements
coverage.py v7.4.4, created at 2024-03-27 03:37 -0700
# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("PackageAlertsConfig", "PackageAlertsTask")

import io
import os
import struct
import sys

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import fastavro
import pandas as pd
# confluent_kafka is not in the standard Rubin environment as it is a third
# party package and is only needed when producing alerts.
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
    from confluent_kafka.admin import AdminClient
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
import lsst.geom as geom
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )

    doWriteFailedAlerts = pexConfig.Field(
        dtype=bool,
        doc="If an alert cannot be sent when doProduceAlerts is set, "
            "write it to disk for debugging purposes.",
        default=False,
    )
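

# A minimal, illustrative override sketch (hypothetical values, not part of
# this module); overrides like these would normally live in a pipeline config
# file:
#
#     config.doWriteAlerts = True
#     config.minCutoutSize = 40
#     config.alertWriteLocation = "/tmp/alerts"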


class PackageAlertsTask(pipeBase.Task):
    """Task for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"
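
    # Pixel scale (1 arcsecond, expressed in degrees) used to seed the local
    # gnomonic WCS in makeLocalTransformMatrix.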
    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")

                # confluent_kafka configures all of its classes with
                # dictionaries. This one sets up the bare minimum that is
                # needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert
                    # that can be sent. We set the batch size to 2 MB.
                    "batch.size": 2097152,
                    "linger.ms": 5,
                }
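                # The admin config below repeats the connection and auth
                # settings, but omits the producer-only batching options.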
                self.kafkaAdminConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                }

                self._server_check()
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("Produce alerts is set but confluent_kafka is not present in "
                                   "the environment. Alerts will not be sent to the alert stream.")

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            calexp,
            template,
            doRunForcedMeasurement=True,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at
        the same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources.
            DataFrame is indexed on ``["diaObjectId"]``.
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        calexp : `lsst.afw.image.ExposureF`
            Calexp used to create the ``diffIm``.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``.
        doRunForcedMeasurement : `bool`, optional
            Flag to indicate whether forced measurement was run.
            This should only be turned off for debugging purposes.
            Added to allow disabling forced sources for performance
            reasons during the ops rehearsal.
        """
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        calexpPhotoCalib = calexp.getPhotoCalib()
        # The template may be None (see the docstring); guard the calib
        # lookup here and the cutout creation below accordingly.
        templatePhotoCalib = template.getPhotoCalib() if template is not None else None
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            if doRunForcedMeasurement:
                objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            else:
                # Send empty table with correct columns
                objDiaForcedSources = diaForcedSources.loc[[]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            calexpCutout = self.createCcdDataCutout(
                calexp,
                sphPoint,
                cutoutExtent,
                calexpPhotoCalib,
                diaSource["diaSourceId"])
            if template is not None:
                templateCutout = self.createCcdDataCutout(
                    template,
                    sphPoint,
                    cutoutExtent,
                    templatePhotoCalib,
                    diaSource["diaSourceId"])
            else:
                templateCutout = None

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   calexpCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create the extent of the cutout box given the size of the square
        bounding box that covers the source footprint.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
        return extent

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        self._server_check()
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                self.producer.flush()

            except KafkaException as e:
                self.log.warning("Kafka error: %s, message was %s bytes", e, sys.getsizeof(alertBytes))

                if self.config.doWriteFailedAlerts:
                    with open(os.path.join(self.config.alertWriteLocation,
                                           f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                        f.write(alertBytes)

        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Calibration object of the image the cutout is cut from.
        srcId : `int`
            Unique id of the DiaSource. Used in logging when an error occurs
            extracting a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrated information from the input
            difference or template image.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(image.getBBox()).contains(point):
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # array_shape is in numpy order: (rows, columns), i.e. (height,
        # width). The original used the width for both axes, which is wrong
        # for cutouts clipped at the image edge.
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        x, y coordinates are relative to the position of the center. The
        matrix is initially calculated in units of arcseconds and then
        converted to degrees. This yields higher precision results due to
        quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate.
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on the sky at which to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation of the local wcs approximation, in units of
            degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
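        # The Jacobian is in arcseconds per pixel (the local gnomonic WCS was
        # seeded with a 1 arcsecond pixel scale); divide by 3600 to get a CD
        # matrix in degrees per pixel.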
        return localMatrix / 3600

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      calexpCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique id of the alert. Currently set to the ``diaSourceId``
            (see DM-24858).
        diaSource : `pandas.DataFrame`
            New single DiaSource to package.
        diaObject : `pandas.DataFrame`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.
        calexpCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the calexp around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.

        Returns
        -------
        alert : `dict`
            The alert packaged as a dictionary.
        """
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if calexpCutout is None:
            alert['cutoutScience'] = None
        else:
            alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data.
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode ``alert``. By
            default the schema is None, in which case the task's configured
            schema is used.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
        """
        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Returns the byte prefix for Confluent Wire Format-style Kafka
        messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
        """
        ConfluentWireFormatHeader = struct.Struct(">bi")
        return ConfluentWireFormatHeader.pack(0, schema_version)

    def _delivery_callback(self, err, msg):
        if err:
            self.log.warning('Message failed delivery: %s', err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())

    def _server_check(self):
        """Check that the alert stream credentials are still valid and the
        server is contactable.

        Raises
        ------
        KafkaException
            Raised if the server is not contactable.
        RuntimeError
            Raised if the server is contactable but there are no topics
            present.
        """
        admin_client = AdminClient(self.kafkaAdminConfig)
        topics = admin_client.list_topics(timeout=0.5).topics

        if not topics:
            raise RuntimeError("No Kafka topics are present on the server.")