Coverage for tests/test_packageAlerts.py: 20%
149 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-05 12:50 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-05 12:50 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
34from lsst.afw.cameraGeom.testUtils import DetectorWrapper
35import lsst.afw.image as afwImage
36import lsst.daf.base as dafBase
37from lsst.dax.apdb import ApdbSql, ApdbSqlConfig
38import lsst.geom as geom
39import lsst.meas.base.tests
40from lsst.sphgeom import Box
41import lsst.utils.tests
42import utils_tests
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    The catalogs are stored in a freshly created, SQLite-backed Apdb and
    then read back, so the returned frames have the layout produced by the
    Apdb rather than by the test-data generators.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``band``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The temporary file only supplies a unique path for the SQLite
    # database; the context manager removes it once the round trip is done.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []

        apdb = ApdbSql(config=apdbConfig)
        apdb.makeSchema()

        wholeSky = Box.full()
        # The schema was just created, so these queries return empty frames;
        # concatenating the inputs onto them presumably aligns the inputs
        # with the Apdb table columns before storing.
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat([apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat([apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        diaObjects = apdb.getDiaObjects(wholeSky)
        # Compute the object id list once; both source queries use it.
        objectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, objectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, objectIds, dateTime)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask`: cutout extents, CCDData cutouts, the
    local WCS approximation, byte serialization, alert dictionaries, and the
    end-to-end ``run`` method.
    """

    def setUp(self):
        """Build a simulated exposure containing one source, plus DiaObjects,
        DiaSources and DiaForcedSources round-tripped through an Apdb.
        """
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure)
        diaSourceHistory = utils_tests.makeDiaSources(10,
                                                      diaObjects["diaObjectId"],
                                                      self.exposure)
        diaForcedSources = utils_tests.makeDiaForcedSources(10,
                                                            diaObjects["diaObjectId"],
                                                            self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date)
        # The round-tripped tables may contain None for missing values;
        # normalize them to NaN so numeric checks such as np.isnan work.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # These two DiaSources (one for each of objects 1 and 2) are the
        # sources to alert on; all remaining sources form the history.
        newSourceLabels = [(1, "g", 9), (2, "g", 10)]
        # Copy explicitly so adding the bboxSize column modifies an
        # independent frame and does not trigger SettingWithCopyWarning.
        self.diaSources = diaSourceHistory.loc[newSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceLabels)

        # Reference TAN WCS for a cutout centered on self.center, built from
        # the exposure's sky origin and CD matrix.
        skyOrigin = self.exposure.getWcs().getSkyOrigin()
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            skyOrigin.getRa().asDegrees(),
            skyOrigin.getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the configured minimum is expected to be raised
        # to that minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A position at (0, 0) degrees is expected to produce no cutout.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # The first index level of diaSources is diaObjectId.
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            # The same cutout is passed as both difference and template.
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register removal immediately so the directory is cleaned up even
        # if run() or one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _writerSchema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if not isinstance(value, float):
                    self.assertEqual(value, expected)
                elif np.isnan(expected):
                    self.assertTrue(np.isnan(value))
                elif expected == 0:
                    # A relative comparison would divide by zero here.
                    self.assertAlmostEqual(value, 0.)
                else:
                    self.assertAlmostEqual(1 - value / expected, 0.)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the inherited `lsst.utils.tests.MemoryTestCase` checks for this
    module; no additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook, called by pytest once before this module's
    tests run.

    Parameters
    ----------
    module : `module`
        The module being set up (not used).
    """
    lsst.utils.tests.init()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()