Coverage for python/lsst/ap/association/packageAlerts.py: 24%
109 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-24 03:57 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-24 03:57 -0700
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ("PackageAlertsConfig", "PackageAlertsTask")
24import io
25import os
27from astropy import wcs
28import astropy.units as u
29from astropy.nddata import CCDData, VarianceUncertainty
30import pandas as pd
32import lsst.alert.packet as alertPack
33import lsst.afw.geom as afwGeom
34import lsst.geom as geom
35import lsst.pex.config as pexConfig
36from lsst.pex.exceptions import InvalidParameterError
37import lsst.pipe.base as pipeBase
38from lsst.utils.timer import timeMethod
41"""Methods for packaging Apdb and Pipelines data into Avro alerts.
42"""
class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    # Avro schema used to serialize the alerts; defaults to the latest schema
    # shipped with ``lsst.alert.packet``.
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file for the avro alerts.",
        default=alertPack.get_path_to_latest_schema()
    )
    # Lower bound, in pixels, on the side length of the square cutouts. It is
    # compared against the DiaSource ``bboxSize`` in
    # ``PackageAlertsTask.createDiaSourceExtent``; larger footprints get
    # larger cutouts.
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    # Directory that receives the ``.avro`` files. NOTE: the default is
    # evaluated once, when this class body is executed, so it reflects the
    # working directory at import time.
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )
class PackageAlertsTask(pipeBase.Task):
    """Task for packaging DiaSource, DiaObject and exposure cutout data into
    Avro alert packets.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    # One arcsecond expressed in degrees. Used by
    # ``makeLocalTransformMatrix`` as the pixel scale of the temporary local
    # WCS it builds.
    _scale = (1.0 * geom.arcseconds).asDegrees()
75 def __init__(self, **kwargs):
76 super().__init__(**kwargs)
77 self.alertSchema = alertPack.Schema.from_file(self.config.schemaFile)
78 os.makedirs(self.config.alertWriteLocation, exist_ok=True)
80 @timeMethod
81 def run(self,
82 diaSourceCat,
83 diaObjectCat,
84 diaSrcHistory,
85 diaForcedSources,
86 diffIm,
87 template,
88 ccdExposureIdBits):
89 """Package DiaSources/Object and exposure data into Avro alerts.
91 Writes Avro alerts to a location determined by the
92 ``alertWriteLocation`` configurable.
94 Parameters
95 ----------
96 diaSourceCat : `pandas.DataFrame`
97 New DiaSources to package. DataFrame should be indexed on
98 ``["diaObjectId", "filterName", "diaSourceId"]``
99 diaObjectCat : `pandas.DataFrame`
100 New and updated DiaObjects matched to the new DiaSources. DataFrame
101 is indexed on ``["diaObjectId"]``
102 diaSrcHistory : `pandas.DataFrame`
103 12 month history of DiaSources matched to the DiaObjects. Excludes
104 the newest DiaSource and is indexed on
105 ``["diaObjectId", "filterName", "diaSourceId"]``
106 diaForcedSources : `pandas.DataFrame`
107 12 month history of DiaForcedSources matched to the DiaObjects.
108 ``["diaObjectId"]``
109 diffIm : `lsst.afw.image.ExposureF`
110 Difference image the sources in ``diaSourceCat`` were detected in.
111 template : `lsst.afw.image.ExposureF` or `None`
112 Template image used to create the ``diffIm``.
113 ccdExposureIdBits : `int`
114 Number of bits used in the ccdVisitId.
115 """
116 alerts = []
117 self._patchDiaSources(diaSourceCat)
118 self._patchDiaSources(diaSrcHistory)
119 ccdVisitId = diffIm.info.id
120 diffImPhotoCalib = diffIm.getPhotoCalib()
121 templatePhotoCalib = template.getPhotoCalib()
122 for srcIndex, diaSource in diaSourceCat.iterrows():
123 # Get all diaSources for the associated diaObject.
124 # TODO: DM-31992 skip DiaSources associated with Solar System
125 # Objects for now.
126 if srcIndex[0] == 0:
127 continue
128 diaObject = diaObjectCat.loc[srcIndex[0]]
129 if diaObject["nDiaSources"] > 1:
130 objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
131 else:
132 objSourceHistory = None
133 objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
134 sphPoint = geom.SpherePoint(diaSource["ra"],
135 diaSource["decl"],
136 geom.degrees)
138 cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
139 diffImCutout = self.createCcdDataCutout(
140 diffIm,
141 sphPoint,
142 cutoutExtent,
143 diffImPhotoCalib,
144 diaSource["diaSourceId"])
145 templateCutout = self.createCcdDataCutout(
146 template,
147 sphPoint,
148 cutoutExtent,
149 templatePhotoCalib,
150 diaSource["diaSourceId"])
152 # TODO: Create alertIds DM-24858
153 alertId = diaSource["diaSourceId"]
154 alerts.append(
155 self.makeAlertDict(alertId,
156 diaSource,
157 diaObject,
158 objSourceHistory,
159 objDiaForcedSources,
160 diffImCutout,
161 templateCutout))
162 with open(os.path.join(self.config.alertWriteLocation,
163 f"{ccdVisitId}.avro"),
164 "wb") as f:
165 self.alertSchema.store_alerts(f, alerts)
167 def _patchDiaSources(self, diaSources):
168 """Add the ``programId`` column to the data.
170 Parameters
171 ----------
172 diaSources : `pandas.DataFrame`
173 DataFrame of DiaSources to patch.
174 """
175 diaSources["programId"] = 0
177 def createDiaSourceExtent(self, bboxSize):
178 """Create a extent for a box for the cutouts given the size of the
179 square BBox that covers the source footprint.
181 Parameters
182 ----------
183 bboxSize : `int`
184 Size of a side of the square bounding box in pixels.
186 Returns
187 -------
188 extent : `lsst.geom.Extent2I`
189 Geom object representing the size of the bounding box.
190 """
191 if bboxSize < self.config.minCutoutSize:
192 extent = geom.Extent2I(self.config.minCutoutSize,
193 self.config.minCutoutSize)
194 else:
195 extent = geom.Extent2I(bboxSize, bboxSize)
196 return extent
198 def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
199 """Grab an image as a cutout and return a calibrated CCDData image.
201 Parameters
202 ----------
203 image : `lsst.afw.image.ExposureF`
204 Image to pull cutout from.
205 skyCenter : `lsst.geom.SpherePoint`
206 Center point of DiaSource on the sky.
207 extent : `lsst.geom.Extent2I`
208 Bounding box to cutout from the image.
209 photoCalib : `lsst.afw.image.PhotoCalib`
210 Calibrate object of the image the cutout is cut from.
211 srcId : `int`
212 Unique id of DiaSource. Used for when an error occurs extracting
213 a cutout.
215 Returns
216 -------
217 ccdData : `astropy.nddata.CCDData` or `None`
218 CCDData object storing the calibrate information from the input
219 difference or template image.
220 """
221 # Catch errors in retrieving the cutout.
222 try:
223 cutout = image.getCutout(skyCenter, extent)
224 except InvalidParameterError:
225 point = image.getWcs().skyToPixel(skyCenter)
226 imBBox = image.getBBox()
227 if not geom.Box2D(image.getBBox()).contains(point):
228 self.log.warning(
229 "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
230 "which is outside the Exposure with bounding box "
231 "((%i, %i), (%i, %i)). Returning None for cutout...",
232 srcId, point.x, point.y,
233 imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
234 else:
235 raise InvalidParameterError(
236 "Failed to retrieve cutout from image for DiaSource with "
237 "id=%i. InvalidParameterError thrown during cutout "
238 "creation. Exiting."
239 % srcId)
240 return None
242 # Find the value of the bottom corner of our cutout's BBox and
243 # subtract 1 so that the CCDData cutout position value will be
244 # [1, 1].
245 cutOutMinX = cutout.getBBox().minX - 1
246 cutOutMinY = cutout.getBBox().minY - 1
247 center = cutout.getWcs().skyToPixel(skyCenter)
248 calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())
250 cutoutWcs = wcs.WCS(naxis=2)
251 cutoutWcs.array_shape = (cutout.getBBox().getWidth(),
252 cutout.getBBox().getWidth())
253 cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
254 cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
255 skyCenter.getDec().asDegrees()]
256 cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
257 center,
258 skyCenter)
260 return CCDData(
261 data=calibCutout.getImage().array,
262 uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
263 flags=calibCutout.getMask().array,
264 wcs=cutoutWcs,
265 meta={"cutMinX": cutOutMinX,
266 "cutMinY": cutOutMinY},
267 unit=u.nJy)
269 def makeLocalTransformMatrix(self, wcs, center, skyCenter):
270 """Create a local, linear approximation of the wcs transformation
271 matrix.
273 The approximation is created as if the center is at RA=0, DEC=0. All
274 comparing x,y coordinate are relative to the position of center. Matrix
275 is initially calculated with units arcseconds and then converted to
276 degrees. This yields higher precision results due to quirks in AST.
278 Parameters
279 ----------
280 wcs : `lsst.afw.geom.SkyWcs`
281 Wcs to approximate
282 center : `lsst.geom.Point2D`
283 Point at which to evaluate the LocalWcs.
284 skyCenter : `lsst.geom.SpherePoint`
285 Point on sky to approximate the Wcs.
287 Returns
288 -------
289 localMatrix : `numpy.ndarray`
290 Matrix representation the local wcs approximation with units
291 degrees.
292 """
293 blankCDMatrix = [[self._scale, 0], [0, self._scale]]
294 localGnomonicWcs = afwGeom.makeSkyWcs(
295 center, skyCenter, blankCDMatrix)
296 measurementToLocalGnomonic = wcs.getTransform().then(
297 localGnomonicWcs.getTransform().inverted()
298 )
299 localMatrix = measurementToLocalGnomonic.getJacobian(center)
300 return localMatrix / 3600
302 def makeAlertDict(self,
303 alertId,
304 diaSource,
305 diaObject,
306 objDiaSrcHistory,
307 objDiaForcedSources,
308 diffImCutout,
309 templateCutout):
310 """Convert data and package into a dictionary alert.
312 Parameters
313 ----------
314 diaSource : `pandas.DataFrame`
315 New single DiaSource to package.
316 diaObject : `pandas.DataFrame`
317 DiaObject that ``diaSource`` is matched to.
318 objDiaSrcHistory : `pandas.DataFrame`
319 12 month history of ``diaObject`` excluding the latest DiaSource.
320 objDiaForcedSources : `pandas.DataFrame`
321 12 month history of ``diaObject`` forced measurements.
322 diffImCutout : `astropy.nddata.CCDData` or `None`
323 Cutout of the difference image around the location of ``diaSource``
324 with a min size set by the ``cutoutSize`` configurable.
325 templateCutout : `astropy.nddata.CCDData` or `None`
326 Cutout of the template image around the location of ``diaSource``
327 with a min size set by the ``cutoutSize`` configurable.
328 """
329 alert = dict()
330 alert['alertId'] = alertId
331 alert['diaSource'] = diaSource.to_dict()
333 if objDiaSrcHistory is None:
334 alert['prvDiaSources'] = objDiaSrcHistory
335 else:
336 alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")
338 if isinstance(objDiaForcedSources, pd.Series):
339 alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
340 else:
341 alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
342 alert['prvDiaNondetectionLimits'] = None
344 alert['diaObject'] = diaObject.to_dict()
346 alert['ssObject'] = None
348 if diffImCutout is None:
349 alert['cutoutDifference'] = None
350 else:
351 alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)
353 if templateCutout is None:
354 alert["cutoutTemplate"] = None
355 else:
356 alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)
358 return alert
360 def streamCcdDataToBytes(self, cutout):
361 """Serialize a cutout into bytes.
363 Parameters
364 ----------
365 cutout : `astropy.nddata.CCDData`
366 Cutout to serialize.
368 Returns
369 -------
370 coutputBytes : `bytes`
371 Input cutout serialized into byte data.
372 """
373 with io.BytesIO() as streamer:
374 cutout.write(streamer, format="fits")
375 cutoutBytes = streamer.getvalue()
376 return cutoutBytes