Coverage for python/lsst/ap/association/packageAlerts.py: 26%
109 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-22 12:38 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API of this module.
__all__ = (
    "PackageAlertsConfig",
    "PackageAlertsTask",
)
24import io
25import os
27from astropy import wcs
28import astropy.units as u
29from astropy.nddata import CCDData, VarianceUncertainty
30import pandas as pd
32import lsst.alert.packet as alertPack
33import lsst.afw.geom as afwGeom
34import lsst.geom as geom
35import lsst.pex.config as pexConfig
36from lsst.pex.exceptions import InvalidParameterError
37import lsst.pipe.base as pipeBase
38from lsst.utils.timer import timeMethod
# This text follows ``__all__`` and the imports, so the compiler would not
# record it as the module docstring; assign it explicitly so ``help()`` and
# pydoc can find it.
__doc__ = """Methods for packaging Apdb and Pipelines data into Avro alerts.
"""
class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Minimum dimension, in pixels, of the square image cutouts to "
            "package in the alert. Larger cutouts are made when the source "
            "footprint bounding box is bigger than this."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        # NOTE(review): this default is computed once, from the working
        # directory at import time, not when the task runs.
        default=os.path.join(os.getcwd(), "alerts"),
    )
class PackageAlertsTask(pipeBase.Task):
    """Tasks for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    # One arcsecond expressed in degrees; used as the (diagonal) CD matrix
    # scale of the local gnomonic WCS in ``makeLocalTransformMatrix``.
    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        # Make sure the output directory exists before ``run`` writes to it.
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            template,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources. DataFrame
            is indexed on ``["diaObjectId"]``
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``. DiaObjects with no entry get
            ``prvDiaForcedSources = None`` in their alert.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``. If `None`, alerts
            are written without a template cutout.

        Notes
        -----
        ``diaSourceCat`` and ``diaSrcHistory`` are modified in place: a
        ``programId`` column is added to each.
        """
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        # The template is optional; only query its calibration when present.
        if template is not None:
            templatePhotoCalib = template.getPhotoCalib()
        else:
            templatePhotoCalib = None

        for srcIndex, diaSource in diaSourceCat.iterrows():
            # The first index level of ``diaSourceCat`` is the diaObjectId.
            diaObjectId = srcIndex[0]
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if diaObjectId == 0:
                continue
            diaObject = diaObjectCat.loc[diaObjectId]
            # Get all previous diaSources for the associated diaObject.
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[diaObjectId]
            else:
                objSourceHistory = None
            # An object may have no forced photometry; avoid a KeyError.
            if diaObjectId in diaForcedSources.index:
                objDiaForcedSources = diaForcedSources.loc[diaObjectId]
            else:
                objDiaForcedSources = None
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            if template is not None:
                templateCutout = self.createCcdDataCutout(
                    template,
                    sphPoint,
                    cutoutExtent,
                    templatePhotoCalib,
                    diaSource["diaSourceId"])
            else:
                templateCutout = None

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   templateCutout))
        with open(os.path.join(self.config.alertWriteLocation,
                               f"{ccdVisitId}.avro"),
                  "wb") as f:
            self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data, in place, set to 0.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create an extent for a box for the cutouts given the size of the
        square BBox that covers the source footprint.

        The extent is never smaller than ``config.minCutoutSize``.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
        return extent

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Calibrate object of the image the cutout is cut from.
        srcId : `int`
            Unique id of DiaSource. Used for when an error occurs extracting
            a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrate information from the input
            difference or template image. `None` if the source centroid lies
            outside ``image``.

        Raises
        ------
        lsst.pex.exceptions.InvalidParameterError
            Raised if the cutout cannot be made even though the centroid lies
            inside ``image``.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError as e:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(imBBox).contains(point):
                # Corners printed as ((minX, minY), (maxX, maxY)).
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.minY, imBBox.maxX, imBBox.maxY)
                return None
            raise InvalidParameterError(
                "Failed to retrieve cutout from image for DiaSource with "
                "id=%i. InvalidParameterError thrown during cutout "
                "creation. Exiting."
                % srcId) from e

        cutoutBBox = cutout.getBBox()
        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1] (FITS pixel indices are 1-based).
        cutOutMinX = cutoutBBox.minX - 1
        cutOutMinY = cutoutBBox.minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # astropy's ``array_shape`` is in numpy order: (ny, nx).
        cutoutWcs.array_shape = (cutoutBBox.getHeight(),
                                 cutoutBBox.getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        A local gnomonic WCS is built with reference pixel ``center``,
        reference sky point ``skyCenter`` and a diagonal CD matrix of one
        arcsecond per pixel. The Jacobian, at ``center``, of the transform
        from ``wcs`` pixels to that local frame gives the matrix in
        arcseconds per pixel, which is then converted to degrees. Working in
        arcseconds first yields higher precision results due to quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation the local wcs approximation with units
            degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
        # arcseconds -> degrees
        return localMatrix / 3600

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Identifier stored under the ``alertId`` key of the alert.
        diaSource : `pandas.DataFrame`
            New single DiaSource to package.
        diaObject : `pandas.DataFrame`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame` or `None`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`, `pandas.Series` or `None`
            12 month history of ``diaObject`` forced measurements. A
            `pandas.Series` is treated as a single forced measurement.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.

        Returns
        -------
        alert : `dict`
            Alert content keyed by alert-schema field name; missing
            components are stored as `None`.
        """
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if objDiaForcedSources is None:
            alert['prvDiaForcedSources'] = None
        elif isinstance(objDiaForcedSources, pd.Series):
            # ``.loc`` on a unique index value returns a single row (Series).
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into FITS byte data.
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes