Coverage for python/lsst/ap/association/packageAlerts.py: 21%
180 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-23 12:24 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ("PackageAlertsConfig", "PackageAlertsTask")
24import io
25import os
26import sys
28from astropy import wcs
29import astropy.units as u
30from astropy.nddata import CCDData, VarianceUncertainty
31import pandas as pd
32import struct
33import fastavro
34# confluent_kafka is not in the standard Rubin environment as it is a third
35# party package and is only needed when producing alerts.
36try:
37 import confluent_kafka
38 from confluent_kafka import KafkaException
39 from confluent_kafka.admin import AdminClient
40except ImportError:
41 confluent_kafka = None
43import lsst.alert.packet as alertPack
44import lsst.afw.geom as afwGeom
45import lsst.geom as geom
46import lsst.pex.config as pexConfig
47from lsst.pex.exceptions import InvalidParameterError
48import lsst.pipe.base as pipeBase
49from lsst.utils.timer import timeMethod
class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )
class PackageAlertsTask(pipeBase.Task):
    """Tasks for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    # One arcsecond expressed in degrees. Used as the pixel scale of the
    # identity CD matrix in makeLocalTransformMatrix.
    _scale = (1.0 * geom.arcseconds).asDegrees()
    def __init__(self, **kwargs):
        """Initialize the alert packaging task.

        Loads the Avro alert schema, creates the alert output directory,
        and, when ``doProduceAlerts`` is set, reads Kafka credentials from
        the environment and constructs a Kafka producer.

        Raises
        ------
        ValueError
            Raised if ``doProduceAlerts`` is set and any of the Kafka
            environment variables (``AP_KAFKA_PRODUCER_PASSWORD``,
            ``AP_KAFKA_PRODUCER_USERNAME``, ``AP_KAFKA_SERVER``,
            ``AP_KAFKA_TOPIC``) is unset or empty.
        RuntimeError
            Raised if ``doProduceAlerts`` is set but confluent_kafka could
            not be imported.
        """
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                # Credentials and connection settings come from the
                # environment so that secrets never live in source code.
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")

                # confluent_kafka configures all of its classes with dictionaries. This one
                # sets up the bare minimum that is needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about the specific
                    # authentication and authorization protocols that should be used when
                    # connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through over
                    # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                    # sensitive, but the password is (of course) a secret value which
                    # should never be committed to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert that can be sent.
                    # We set the batch size to 2 Mb.
                    "batch.size": 2097152,
                    "linger.ms": 5,
                }
                self.kafkaAdminConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about the specific
                    # authentication and authorization protocols that should be used when
                    # connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through over
                    # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                    # sensitive, but the password is (of course) a secret value which
                    # should never be committed to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                }

                # Fail fast on an unreachable broker before constructing the
                # producer.
                self._server_check()
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("Produce alerts is set but confluent_kafka is not present in "
                                   "the environment. Alerts will not be sent to the alert stream.")
158 @timeMethod
159 def run(self,
160 diaSourceCat,
161 diaObjectCat,
162 diaSrcHistory,
163 diaForcedSources,
164 diffIm,
165 calexp,
166 template,
167 doRunForcedMeasurement=True,
168 ):
169 """Package DiaSources/Object and exposure data into Avro alerts.
171 Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
172 and written to disk if ``doWriteAlerts`` is set. Both can be set at the
173 same time, and are independent of one another.
175 Writes Avro alerts to a location determined by the
176 ``alertWriteLocation`` configurable.
178 Parameters
179 ----------
180 diaSourceCat : `pandas.DataFrame`
181 New DiaSources to package. DataFrame should be indexed on
182 ``["diaObjectId", "band", "diaSourceId"]``
183 diaObjectCat : `pandas.DataFrame`
184 New and updated DiaObjects matched to the new DiaSources. DataFrame
185 is indexed on ``["diaObjectId"]``
186 diaSrcHistory : `pandas.DataFrame`
187 12 month history of DiaSources matched to the DiaObjects. Excludes
188 the newest DiaSource and is indexed on
189 ``["diaObjectId", "band", "diaSourceId"]``
190 diaForcedSources : `pandas.DataFrame`
191 12 month history of DiaForcedSources matched to the DiaObjects.
192 ``["diaObjectId"]``
193 diffIm : `lsst.afw.image.ExposureF`
194 Difference image the sources in ``diaSourceCat`` were detected in.
195 calexp : `lsst.afw.image.ExposureF`
196 Calexp used to create the ``diffIm``.
197 template : `lsst.afw.image.ExposureF` or `None`
198 Template image used to create the ``diffIm``.
199 doRunForcedMeasurement : `bool`, optional
200 Flag to indicate whether forced measurement was run.
201 This should only be turned off for debugging purposes.
202 Added to allow disabling forced sources for performance
203 reasons during the ops rehearsal.
204 """
205 alerts = []
206 self._patchDiaSources(diaSourceCat)
207 self._patchDiaSources(diaSrcHistory)
208 ccdVisitId = diffIm.info.id
209 diffImPhotoCalib = diffIm.getPhotoCalib()
210 calexpPhotoCalib = calexp.getPhotoCalib()
211 templatePhotoCalib = template.getPhotoCalib()
212 for srcIndex, diaSource in diaSourceCat.iterrows():
213 # Get all diaSources for the associated diaObject.
214 # TODO: DM-31992 skip DiaSources associated with Solar System
215 # Objects for now.
216 if srcIndex[0] == 0:
217 continue
218 diaObject = diaObjectCat.loc[srcIndex[0]]
219 if diaObject["nDiaSources"] > 1:
220 objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
221 else:
222 objSourceHistory = None
223 if doRunForcedMeasurement:
224 objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
225 else:
226 # Send empty table with correct columns
227 objDiaForcedSources = diaForcedSources.loc[[]]
228 sphPoint = geom.SpherePoint(diaSource["ra"],
229 diaSource["dec"],
230 geom.degrees)
232 cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
233 diffImCutout = self.createCcdDataCutout(
234 diffIm,
235 sphPoint,
236 cutoutExtent,
237 diffImPhotoCalib,
238 diaSource["diaSourceId"])
239 calexpCutout = self.createCcdDataCutout(
240 calexp,
241 sphPoint,
242 cutoutExtent,
243 calexpPhotoCalib,
244 diaSource["diaSourceId"])
245 templateCutout = self.createCcdDataCutout(
246 template,
247 sphPoint,
248 cutoutExtent,
249 templatePhotoCalib,
250 diaSource["diaSourceId"])
252 # TODO: Create alertIds DM-24858
253 alertId = diaSource["diaSourceId"]
254 alerts.append(
255 self.makeAlertDict(alertId,
256 diaSource,
257 diaObject,
258 objSourceHistory,
259 objDiaForcedSources,
260 diffImCutout,
261 calexpCutout,
262 templateCutout))
264 if self.config.doProduceAlerts:
265 self.produceAlerts(alerts, ccdVisitId)
267 if self.config.doWriteAlerts:
268 with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f:
269 self.alertSchema.store_alerts(f, alerts)
    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data, in place.

        The alert schema requires a ``programId`` field that the DiaSource
        tables do not carry; every row gets the same placeholder value 0.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch. Modified in place.
        """
        diaSources["programId"] = 0
281 def createDiaSourceExtent(self, bboxSize):
282 """Create an extent for a box for the cutouts given the size of the
283 square BBox that covers the source footprint.
285 Parameters
286 ----------
287 bboxSize : `int`
288 Size of a side of the square bounding box in pixels.
290 Returns
291 -------
292 extent : `lsst.geom.Extent2I`
293 Geom object representing the size of the bounding box.
294 """
295 if bboxSize < self.config.minCutoutSize:
296 extent = geom.Extent2I(self.config.minCutoutSize,
297 self.config.minCutoutSize)
298 else:
299 extent = geom.Extent2I(bboxSize, bboxSize)
300 return extent
    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Alerts which fail to send are written to disk instead, named by
        ccdVisitId and alertId.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        # Verify the broker is reachable before serializing anything.
        self._server_check()
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                # Flush per message so a failure surfaces for the specific
                # alert that caused it and can be written to disk below.
                self.producer.flush()

            except KafkaException as e:
                self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))
                # Fall back to persisting the failed alert on disk.
                with open(os.path.join(self.config.alertWriteLocation,
                                       f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                    f.write(alertBytes)

        self.producer.flush()
329 def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
330 """Grab an image as a cutout and return a calibrated CCDData image.
332 Parameters
333 ----------
334 image : `lsst.afw.image.ExposureF`
335 Image to pull cutout from.
336 skyCenter : `lsst.geom.SpherePoint`
337 Center point of DiaSource on the sky.
338 extent : `lsst.geom.Extent2I`
339 Bounding box to cutout from the image.
340 photoCalib : `lsst.afw.image.PhotoCalib`
341 Calibrate object of the image the cutout is cut from.
342 srcId : `int`
343 Unique id of DiaSource. Used for when an error occurs extracting
344 a cutout.
346 Returns
347 -------
348 ccdData : `astropy.nddata.CCDData` or `None`
349 CCDData object storing the calibrate information from the input
350 difference or template image.
351 """
352 # Catch errors in retrieving the cutout.
353 try:
354 cutout = image.getCutout(skyCenter, extent)
355 except InvalidParameterError:
356 point = image.getWcs().skyToPixel(skyCenter)
357 imBBox = image.getBBox()
358 if not geom.Box2D(image.getBBox()).contains(point):
359 self.log.warning(
360 "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
361 "which is outside the Exposure with bounding box "
362 "((%i, %i), (%i, %i)). Returning None for cutout...",
363 srcId, point.x, point.y,
364 imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
365 else:
366 raise InvalidParameterError(
367 "Failed to retrieve cutout from image for DiaSource with "
368 "id=%i. InvalidParameterError thrown during cutout "
369 "creation. Exiting."
370 % srcId)
371 return None
373 # Find the value of the bottom corner of our cutout's BBox and
374 # subtract 1 so that the CCDData cutout position value will be
375 # [1, 1].
376 cutOutMinX = cutout.getBBox().minX - 1
377 cutOutMinY = cutout.getBBox().minY - 1
378 center = cutout.getWcs().skyToPixel(skyCenter)
379 calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())
381 cutoutWcs = wcs.WCS(naxis=2)
382 cutoutWcs.array_shape = (cutout.getBBox().getWidth(),
383 cutout.getBBox().getWidth())
384 cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
385 cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
386 skyCenter.getDec().asDegrees()]
387 cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
388 center,
389 skyCenter)
391 return CCDData(
392 data=calibCutout.getImage().array,
393 uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
394 flags=calibCutout.getMask().array,
395 wcs=cutoutWcs,
396 meta={"cutMinX": cutOutMinX,
397 "cutMinY": cutOutMinY},
398 unit=u.nJy)
400 def makeLocalTransformMatrix(self, wcs, center, skyCenter):
401 """Create a local, linear approximation of the wcs transformation
402 matrix.
404 The approximation is created as if the center is at RA=0, DEC=0. All
405 comparing x,y coordinate are relative to the position of center. Matrix
406 is initially calculated with units arcseconds and then converted to
407 degrees. This yields higher precision results due to quirks in AST.
409 Parameters
410 ----------
411 wcs : `lsst.afw.geom.SkyWcs`
412 Wcs to approximate
413 center : `lsst.geom.Point2D`
414 Point at which to evaluate the LocalWcs.
415 skyCenter : `lsst.geom.SpherePoint`
416 Point on sky to approximate the Wcs.
418 Returns
419 -------
420 localMatrix : `numpy.ndarray`
421 Matrix representation the local wcs approximation with units
422 degrees.
423 """
424 blankCDMatrix = [[self._scale, 0], [0, self._scale]]
425 localGnomonicWcs = afwGeom.makeSkyWcs(
426 center, skyCenter, blankCDMatrix)
427 measurementToLocalGnomonic = wcs.getTransform().then(
428 localGnomonicWcs.getTransform().inverted()
429 )
430 localMatrix = measurementToLocalGnomonic.getJacobian(center)
431 return localMatrix / 3600
433 def makeAlertDict(self,
434 alertId,
435 diaSource,
436 diaObject,
437 objDiaSrcHistory,
438 objDiaForcedSources,
439 diffImCutout,
440 calexpCutout,
441 templateCutout):
442 """Convert data and package into a dictionary alert.
444 Parameters
445 ----------
446 diaSource : `pandas.DataFrame`
447 New single DiaSource to package.
448 diaObject : `pandas.DataFrame`
449 DiaObject that ``diaSource`` is matched to.
450 objDiaSrcHistory : `pandas.DataFrame`
451 12 month history of ``diaObject`` excluding the latest DiaSource.
452 objDiaForcedSources : `pandas.DataFrame`
453 12 month history of ``diaObject`` forced measurements.
454 diffImCutout : `astropy.nddata.CCDData` or `None`
455 Cutout of the difference image around the location of ``diaSource``
456 with a min size set by the ``cutoutSize`` configurable.
457 calexpCutout : `astropy.nddata.CCDData` or `None`
458 Cutout of the calexp around the location of ``diaSource``
459 with a min size set by the ``cutoutSize`` configurable.
460 templateCutout : `astropy.nddata.CCDData` or `None`
461 Cutout of the template image around the location of ``diaSource``
462 with a min size set by the ``cutoutSize`` configurable.
463 """
464 alert = dict()
465 alert['alertId'] = alertId
466 alert['diaSource'] = diaSource.to_dict()
468 if objDiaSrcHistory is None:
469 alert['prvDiaSources'] = objDiaSrcHistory
470 else:
471 alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")
473 if isinstance(objDiaForcedSources, pd.Series):
474 alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
475 else:
476 alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
477 alert['prvDiaNondetectionLimits'] = None
479 alert['diaObject'] = diaObject.to_dict()
481 alert['ssObject'] = None
483 if diffImCutout is None:
484 alert['cutoutDifference'] = None
485 else:
486 alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)
488 if calexpCutout is None:
489 alert['cutoutScience'] = None
490 else:
491 alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)
493 if templateCutout is None:
494 alert["cutoutTemplate"] = None
495 else:
496 alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)
498 return alert
500 def streamCcdDataToBytes(self, cutout):
501 """Serialize a cutout into bytes.
503 Parameters
504 ----------
505 cutout : `astropy.nddata.CCDData`
506 Cutout to serialize.
508 Returns
509 -------
510 coutputBytes : `bytes`
511 Input cutout serialized into byte data.
512 """
513 with io.BytesIO() as streamer:
514 cutout.write(streamer, format="fits")
515 cutoutBytes = streamer.getvalue()
516 return cutoutBytes
518 def _serializeAlert(self, alert, schema=None, schema_id=0):
519 """Serialize an alert to a byte sequence for sending to Kafka.
521 Parameters
522 ----------
523 alert : `dict`
524 An alert payload to be serialized.
525 schema : `dict`, optional
526 An Avro schema definition describing how to encode `alert`. By default,
527 the schema is None, which sets it to the latest schema available.
528 schema_id : `int`, optional
529 The Confluent Schema Registry ID of the schema. By default, 0 (an
530 invalid ID) is used, indicating that the schema is not registered.
532 Returns
533 -------
534 serialized : `bytes`
535 The byte sequence describing the alert, including the Confluent Wire
536 Format prefix.
537 """
538 if schema is None:
539 schema = self.alertSchema.definition
541 buf = io.BytesIO()
542 # TODO: Use a proper schema versioning system (DM-42606)
543 buf.write(self._serializeConfluentWireHeader(schema_id))
544 fastavro.schemaless_writer(buf, schema, alert)
545 return buf.getvalue()
547 @staticmethod
548 def _serializeConfluentWireHeader(schema_version):
549 """Returns the byte prefix for Confluent Wire Format-style Kafka messages.
551 Parameters
552 ----------
553 schema_version : `int`
554 A version number which indicates the Confluent Schema Registry ID
555 number of the Avro schema used to encode the message that follows this
556 header.
558 Returns
559 -------
560 header : `bytes`
561 The 5-byte encoded message prefix.
563 Notes
564 -----
565 The Confluent Wire Format is described more fully here:
566 https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
567 """
568 ConfluentWireFormatHeader = struct.Struct(">bi")
569 return ConfluentWireFormatHeader.pack(0, schema_version)
571 def _delivery_callback(self, err, msg):
572 if err:
573 self.log.warning('Message failed delivery: %s\n' % err)
574 else:
575 self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())
577 def _server_check(self):
578 """Checks if the alert stream credentials are still valid and the
579 server is contactable.
581 Raises
582 -------
583 KafkaException
584 Raised if the server us not contactable.
585 RuntimeError
586 Raised if the server is contactable but there are no topics
587 present.
588 """
589 admin_client = AdminClient(self.kafkaAdminConfig)
590 topics = admin_client.list_topics(timeout=0.5).topics
592 if not topics:
593 raise RuntimeError()