# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("PackageAlertsConfig", "PackageAlertsTask")

import io
import os
import struct
import sys

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import fastavro
import pandas as pd

# confluent_kafka is not in the standard Rubin environment as it is a third
# party package and is only needed when producing alerts.
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
    from confluent_kafka.admin import AdminClient
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
import lsst.geom as geom
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka "
            "is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )

    doWriteFailedAlerts = pexConfig.Field(
        dtype=bool,
        doc="If an alert cannot be sent when doProduceAlerts is set, "
            "write it to disk for debugging purposes.",
        default=False,
    )

    maxTimeout = pexConfig.Field(
        dtype=float,
        doc="Sets the maximum time in seconds to wait for the alert stream "
            "broker to respond to a query before timing out.",
        default=15.0,
    )

    deliveryTimeout = pexConfig.Field(
        dtype=float,
        doc="Maximum time in milliseconds to wait for the producer to "
            "deliver an alert.",
        default=1200.0,
    )
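
# A minimal config override sketch for enabling alert production, e.g. in a
# pipeline configuration file (hypothetical values; the Kafka credentials
# themselves are read from environment variables in
# PackageAlertsTask.__init__):
#
#     config.doProduceAlerts = True
#     config.doWriteFailedAlerts = True
#     config.maxTimeout = 30.0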


class PackageAlertsTask(pipeBase.Task):
    """Task for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")

                # confluent_kafka configures all of its classes with
                # dictionaries. This one sets up the bare minimum that is
                # needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert
                    # that can be sent. We set the batch size to 2 MB.
                    "batch.size": 2097152,
                    "linger.ms": 5,
                    "delivery.timeout.ms": self.config.deliveryTimeout,
                }
                self.kafkaAdminConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                }

                self._server_check()
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("doProduceAlerts is set but confluent_kafka "
                                   "is not present in the environment. Alerts "
                                   "will not be sent to the alert stream.")

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            calexp,
            template,
            doRunForcedMeasurement=True,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at
        the same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources.
            DataFrame is indexed on ``["diaObjectId"]``.
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        calexp : `lsst.afw.image.ExposureF`
            Calexp used to create the ``diffIm``.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``.
        doRunForcedMeasurement : `bool`, optional
            Flag to indicate whether forced measurement was run.
            This should only be turned off for debugging purposes.
            Added to allow disabling forced sources for performance
            reasons during the ops rehearsal.
        """
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        calexpPhotoCalib = calexp.getPhotoCalib()
        # The template is documented as optional, so guard against None.
        templatePhotoCalib = template.getPhotoCalib() if template is not None else None
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            if doRunForcedMeasurement:
                objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            else:
                # Send an empty table with the correct columns.
                objDiaForcedSources = diaForcedSources.loc[[]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            calexpCutout = self.createCcdDataCutout(
                calexp,
                sphPoint,
                cutoutExtent,
                calexpPhotoCalib,
                diaSource["diaSourceId"])
            if template is not None:
                templateCutout = self.createCcdDataCutout(
                    template,
                    sphPoint,
                    cutoutExtent,
                    templatePhotoCalib,
                    diaSource["diaSourceId"])
            else:
                templateCutout = None

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   calexpCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation,
                                   f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)
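        # Sketch for reading the packaged alerts back, assuming a reader
        # matching the writer above (hypothetical helper; adjust to the
        # lsst.alert.packet API in use):
        #
        #     with open(f"{alertWriteLocation}/{ccdVisitId}.avro", "rb") as f:
        #         retrieved_schema, alert_records = alertPack.retrieve_alerts(f)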

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create an extent for a box for the cutouts given the size of the
        square BBox that covers the source footprint.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
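        # E.g. with the default minCutoutSize of 30, a bboxSize of 10 is
        # padded to a 30x30 extent, while a bboxSize of 50 yields 50x50.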
        return extent

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                self.producer.flush()

            except KafkaException as e:
                self.log.warning("Kafka error: %s, message was %s bytes", e, sys.getsizeof(alertBytes))

                if self.config.doWriteFailedAlerts:
                    with open(os.path.join(self.config.alertWriteLocation,
                                           f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                        f.write(alertBytes)
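
        # produce() enqueues messages asynchronously; this final flush()
        # blocks until all outstanding messages are delivered (or time out)
        # and their delivery callbacks have run.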
        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Calibration object of the image the cutout is cut from.
        srcId : `int`
            Unique id of DiaSource. Used when an error occurs extracting
            a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrated information from the input
            difference or template image.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(imBBox).contains(point):
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # array_shape is (ny, nx); use the bbox height and width so that
        # cutouts clipped at the image edge keep the correct shape.
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        x,y coordinates are relative to the position of center. The matrix is
        initially calculated in units of arcseconds and then converted to
        degrees. This yields higher precision results due to quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate.
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation of the local wcs approximation, in units
            of degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
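        # The Jacobian above is in arcseconds per pixel (the local gnomonic
        # WCS uses a 1 arcsecond pixel scale); dividing by 3600 converts the
        # CD matrix to degrees per pixel.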
        return localMatrix / 3600

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      calexpCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique identifier for the alert. Currently set to the id of the
            new DiaSource (see DM-24858).
        diaSource : `pandas.DataFrame`
            New single DiaSource to package.
        diaObject : `pandas.DataFrame`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.
        calexpCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the calexp around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.

        Returns
        -------
        alert : `dict`
            The packaged alert dictionary.
        """
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if calexpCutout is None:
            alert['cutoutScience'] = None
        else:
            alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data.
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode ``alert``. By
            default the schema is None, which uses this task's configured
            schema.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
        """
        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Returns the byte prefix for Confluent Wire Format-style Kafka
        messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
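
        Examples
        --------
        A quick structural check of the 5-byte prefix for schema ID 1 (a zero
        magic byte followed by the ID as a big-endian 32-bit integer):

        >>> header = PackageAlertsTask._serializeConfluentWireHeader(1)
        >>> len(header), header[0], int.from_bytes(header[1:], "big")
        (5, 0, 1)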
590 """
591 ConfluentWireFormatHeader = struct.Struct(">bi")
592 return ConfluentWireFormatHeader.pack(0, schema_version)

    def _delivery_callback(self, err, msg):
        if err:
            self.log.warning('Message failed delivery: %s', err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())

    def _server_check(self):
        """Check that the alert stream credentials are still valid and the
        server is contactable.

        Raises
        ------
        KafkaException
            Raised if the server is not contactable.
        RuntimeError
            Raised if the server is contactable but there are no topics
            present.
        """
        admin_client = AdminClient(self.kafkaAdminConfig)
        topics = admin_client.list_topics(timeout=self.config.maxTimeout).topics

        if not topics:
            raise RuntimeError("Server is contactable but no Kafka topics are present.")
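

# A minimal offline sketch of the serialization path, without a Kafka broker
# or butler data (illustrative only; ``alert`` must be a dict conforming to
# the configured Avro schema):
#
#     task = PackageAlertsTask()
#     payload = task._serializeAlert(alert, schema_id=1)
#     assert payload[:5] == PackageAlertsTask._serializeConfluentWireHeader(1)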