Coverage for python/lsst/ap/association/packageAlerts.py: 22%
171 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 03:19 -0700
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ("PackageAlertsConfig", "PackageAlertsTask")
24import io
25import os
26import sys
28from astropy import wcs
29import astropy.units as u
30from astropy.nddata import CCDData, VarianceUncertainty
31import pandas as pd
32import struct
33import fastavro
34# confluent_kafka is not in the standard Rubin environment as it is a third
35# party package and is only needed when producing alerts.
36try:
37 import confluent_kafka
38 from confluent_kafka import KafkaException
39except ImportError:
40 confluent_kafka = None
42import lsst.alert.packet as alertPack
43import lsst.afw.geom as afwGeom
44import lsst.geom as geom
45import lsst.pex.config as pexConfig
46from lsst.pex.exceptions import InvalidParameterError
47import lsst.pipe.base as pipeBase
48from lsst.utils.timer import timeMethod
class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )
class PackageAlertsTask(pipeBase.Task):
    """Tasks for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    # Pixel scale (1 arcsecond in degrees) used to seed the local linear
    # WCS approximation in ``makeLocalTransformMatrix``.
    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            # Guard clause: fail fast if the optional dependency is missing.
            if confluent_kafka is None:
                raise RuntimeError("Produce alerts is set but confluent_kafka is not present in "
                                   "the environment. Alerts will not be sent to the alert stream.")

            # Kafka credentials and connection details are supplied through
            # the environment so that secrets never appear in source/config.
            self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
            if not self.password:
                raise ValueError("Kafka password environment variable was not set.")
            self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
            if not self.username:
                raise ValueError("Kafka username environment variable was not set.")
            self.server = os.getenv("AP_KAFKA_SERVER")
            if not self.server:
                raise ValueError("Kafka server environment variable was not set.")
            self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
            if not self.kafkaTopic:
                raise ValueError("Kafka topic environment variable was not set.")

            # confluent_kafka configures all of its classes with dictionaries. This one
            # sets up the bare minimum that is needed.
            self.kafkaConfig = {
                # This is the URL to use to connect to the Kafka cluster.
                "bootstrap.servers": self.server,
                # These next two properties tell the Kafka client about the specific
                # authentication and authorization protocols that should be used when
                # connecting.
                "security.protocol": "SASL_PLAINTEXT",
                "sasl.mechanisms": "SCRAM-SHA-512",
                # The sasl.username and sasl.password are passed through over
                # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                # sensitive, but the password is (of course) a secret value which
                # should never be committed to source code.
                "sasl.username": self.username,
                "sasl.password": self.password,
                # Batch size limits the largest size of a kafka alert that can be sent.
                # We set the batch size to 2 Mb.
                "batch.size": 2097152,
                "linger.ms": 5,
            }
            self.producer = confluent_kafka.Producer(**self.kafkaConfig)

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            calexp,
            template,
            doRunForcedMeasurement=True,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at the
        same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources. DataFrame
            is indexed on ``["diaObjectId"]``
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            ``["diaObjectId"]``
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        calexp : `lsst.afw.image.ExposureF`
            Calexp used to create the ``diffIm``.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``. If `None`, no
            template cutouts are packaged in the alerts.
        doRunForcedMeasurement : `bool`, optional
            Flag to indicate whether forced measurement was run.
            This should only be turned off for debugging purposes.
            Added to allow disabling forced sources for performance
            reasons during the ops rehearsal.
        """
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        calexpPhotoCalib = calexp.getPhotoCalib()
        # ``template`` is documented as possibly None; guard so we do not
        # crash with an AttributeError when no template is supplied.
        templatePhotoCalib = template.getPhotoCalib() if template is not None else None
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            if doRunForcedMeasurement:
                objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            else:
                # Send empty table with correct columns
                objDiaForcedSources = diaForcedSources.loc[[]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            calexpCutout = self.createCcdDataCutout(
                calexp,
                sphPoint,
                cutoutExtent,
                calexpPhotoCalib,
                diaSource["diaSourceId"])
            if template is not None:
                templateCutout = self.createCcdDataCutout(
                    template,
                    sphPoint,
                    cutoutExtent,
                    templatePhotoCalib,
                    diaSource["diaSourceId"])
            else:
                templateCutout = None

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   calexpCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create an extent for a box for the cutouts given the size of the
        square BBox that covers the source footprint.

        The extent is never smaller than the configured ``minCutoutSize``.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
        return extent

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                self.producer.flush()

            except KafkaException as e:
                self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))

                # Fall back to disk so the alert is not lost; the file name
                # encodes both the visit and the individual alert id.
                with open(os.path.join(self.config.alertWriteLocation,
                                       f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                    f.write(alertBytes)

        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Calibrate object of the image the cutout is cut from.
        srcId : `int`
            Unique id of DiaSource. Used for when an error occurs extracting
            a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrate information from the input
            difference or template image.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(image.getBBox()).contains(point):
                # Off-image centroids are expected occasionally; warn and
                # return None rather than failing the whole visit.
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                # Centroid was inside the image, so the failure is
                # unexpected; re-raise with context.
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # astropy WCS array_shape is (ny, nx); previous code used the width
        # for both axes, which was wrong for non-square cutouts.
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        comparing x,y coordinate are relative to the position of center. Matrix
        is initially calculated with units arcseconds and then converted to
        degrees. This yields higher precision results due to quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation the local wcs approximation with units
            degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
        # Convert arcseconds to degrees.
        return localMatrix / 3600

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      calexpCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique id to assign to the alert; currently the diaSourceId.
        diaSource : `pandas.DataFrame`
            New single DiaSource to package.
        diaObject : `pandas.DataFrame`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of ``diaSource``
            with a min size set by the ``cutoutSize`` configurable.
        calexpCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the calexp around the location of ``diaSource``
            with a min size set by the ``cutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of ``diaSource``
            with a min size set by the ``cutoutSize`` configurable.
        """
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if calexpCutout is None:
            alert['cutoutScience'] = None
        else:
            alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        coutputBytes : `bytes`
            Input cutout serialized into byte data.
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode `alert`. By default,
            the schema is None, which sets it to the latest schema available.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent Wire
            Format prefix.
        """
        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Returns the byte prefix for Confluent Wire Format-style Kafka messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows this
            header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
        """
        ConfluentWireFormatHeader = struct.Struct(">bi")
        return ConfluentWireFormatHeader.pack(0, schema_version)

    def _delivery_callback(self, err, msg):
        # Kafka invokes this once per produced message: log failures loudly,
        # successes at debug level.
        if err:
            self.log.warning('Message failed delivery: %s\n' % err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())