Coverage for python/lsst/ap/association/packageAlerts.py: 23%
164 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-14 11:51 -0700
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ("PackageAlertsConfig", "PackageAlertsTask")
24import io
25import os
26import sys
28from astropy import wcs
29import astropy.units as u
30from astropy.nddata import CCDData, VarianceUncertainty
31import pandas as pd
32import struct
33import fastavro
34# confluent_kafka is not in the standard Rubin environment as it is a third
35# party package and is only needed when producing alerts.
36try:
37 import confluent_kafka
38 from confluent_kafka import KafkaException
39except ImportError:
40 confluent_kafka = None
42import lsst.alert.packet as alertPack
43import lsst.afw.geom as afwGeom
44import lsst.geom as geom
45import lsst.pex.config as pexConfig
46from lsst.pex.exceptions import InvalidParameterError
47import lsst.pipe.base as pipeBase
48from lsst.utils.timer import timeMethod
class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    # URI of the Avro schema used to serialize alert packets; defaults to
    # the newest schema shipped with lsst.alert.packet.
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    # Floor on the cutout side length; smaller source footprints are padded
    # up to this size when cutouts are extracted.
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    # Directory that receives .avro files (created at task construction).
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    # Streaming and disk output are independent switches; both may be on.
    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )
class PackageAlertsTask(pipeBase.Task):
    """Tasks for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    # One arcsecond expressed in degrees; used to seed the CD matrix of the
    # local linear WCS approximation in ``makeLocalTransformMatrix``.
    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            # Guard clause: producing alerts is impossible without the
            # optional confluent_kafka dependency.
            if confluent_kafka is None:
                raise RuntimeError("Produce alerts is set but confluent_kafka is not present in "
                                   "the environment. Alerts will not be sent to the alert stream.")

            # All connection settings come from the environment so that
            # secrets are never committed to source code.
            self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
            if not self.password:
                raise ValueError("Kafka password environment variable was not set.")
            self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
            if not self.username:
                raise ValueError("Kafka username environment variable was not set.")
            self.server = os.getenv("AP_KAFKA_SERVER")
            if not self.server:
                raise ValueError("Kafka server environment variable was not set.")
            self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
            if not self.kafkaTopic:
                raise ValueError("Kafka topic environment variable was not set.")

            # confluent_kafka configures all of its classes with dictionaries.
            # This one sets up the bare minimum that is needed.
            self.kafkaConfig = {
                # This is the URL to use to connect to the Kafka cluster.
                "bootstrap.servers": self.server,
                # These next two properties tell the Kafka client about the
                # specific authentication and authorization protocols that
                # should be used when connecting.
                "security.protocol": "SASL_PLAINTEXT",
                "sasl.mechanisms": "SCRAM-SHA-512",
                # The sasl.username and sasl.password are passed through over
                # SCRAM-SHA-512 auth to connect to the cluster. The username
                # is not sensitive, but the password is (of course) a secret
                # value which should never be committed to source code.
                "sasl.username": self.username,
                "sasl.password": self.password,
                # Batch size limits the largest size of a kafka alert that
                # can be sent. We set the batch size to 2 Mb.
                "batch.size": 2097152,
                "linger.ms": 5,
            }
            self.producer = confluent_kafka.Producer(**self.kafkaConfig)

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            template,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at the
        same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources. DataFrame
            is indexed on ``["diaObjectId"]``
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``.
        """
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        # ``template`` is documented as possibly `None`; guard the calib
        # lookup so a missing template does not raise AttributeError.
        templatePhotoCalib = (template.getPhotoCalib()
                              if template is not None else None)
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                # A single-source object has no prior history to attach.
                objSourceHistory = None
            objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            if template is not None:
                templateCutout = self.createCcdDataCutout(
                    template,
                    sphPoint,
                    cutoutExtent,
                    templatePhotoCalib,
                    diaSource["diaSourceId"])
            else:
                templateCutout = None

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation,
                                   f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch. Modified in place.
        """
        # The alert schema requires a programId; none is tracked upstream yet,
        # so every row gets 0.
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create an extent for a box for the cutouts given the size of the
        square BBox that covers the source footprint.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        # Pad small footprints up to the configured minimum cutout size.
        side = max(bboxSize, self.config.minCutoutSize)
        return geom.Extent2I(side, side)

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Alerts which fail to be sent are written to disk instead so they are
        not lost.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
            except KafkaException as e:
                self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))

                # Fall back to writing the failed alert to disk so it can be
                # replayed later.
                with open(os.path.join(self.config.alertWriteLocation,
                                       f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                    f.write(alertBytes)

        # Flush once after all alerts are queued; flushing per message (as
        # before) defeats the producer's batching.
        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Calibrate object of the image the cutout is cut from.
        srcId : `int`
            Unique id of DiaSource. Used for when an error occurs extracting
            a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrate information from the input
            difference or template image.

        Raises
        ------
        lsst.pex.exceptions.InvalidParameterError
            Raised if the cutout fails for a source whose centroid lies
            inside the image bounding box.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(image.getBBox()).contains(point):
                # Off-image source: warn and skip the cutout.
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                # On-image failure is unexpected; surface it to the caller.
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # array_shape follows the numpy (rows, columns) = (height, width)
        # convention. The previous code used the width for both axes, which
        # was wrong for non-square cutouts.
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        comparing x,y coordinate are relative to the position of center. Matrix
        is initially calculated with units arcseconds and then converted to
        degrees. This yields higher precision results due to quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation the local wcs approximation with units
            degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
        # Jacobian is in arcseconds; convert to degrees.
        return localMatrix / 3600

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique id to assign to this alert packet.
        diaSource : `pandas.Series`
            New single DiaSource to package.
        diaObject : `pandas.Series`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame` or `None`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame` or `pandas.Series`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of ``diaSource``
            with a min size set by the ``cutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of ``diaSource``
            with a min size set by the ``cutoutSize`` configurable.

        Returns
        -------
        alert : `dict`
            Alert packet ready for Avro serialization.
        """
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        # ``.loc`` on a single-row history yields a Series; normalize to a
        # one-element list of dicts.
        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        # TODO: DM-31992 Solar System Objects are not yet supported.
        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data (FITS format).
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode `alert`. By
            default, the schema is None, which sets it to the latest schema
            available.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
        """
        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Returns the byte prefix for Confluent Wire Format-style Kafka messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
        """
        # Magic byte 0, then the 4-byte big-endian schema id.
        ConfluentWireFormatHeader = struct.Struct(">bi")
        return ConfluentWireFormatHeader.pack(0, schema_version)

    def _delivery_callback(self, err, msg):
        """Log the outcome of an asynchronous Kafka produce call.

        Parameters
        ----------
        err : `confluent_kafka.KafkaError` or `None`
            Error reported by the broker, or None on success.
        msg : `confluent_kafka.Message`
            The message that was (or failed to be) delivered.
        """
        if err:
            self.log.warning('Message failed delivery: %s\n' % err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())