Coverage for tests/test_packageAlerts.py: 20%
149 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-30 06:06 -0700
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
34from lsst.afw.cameraGeom.testUtils import DetectorWrapper
35import lsst.afw.image as afwImage
36import lsst.daf.base as dafBase
37from lsst.dax.apdb import ApdbSql, ApdbSqlConfig
38import lsst.geom as geom
39import lsst.meas.base.tests
40from lsst.sphgeom import Box
41import lsst.utils.tests
42import utils_tests
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``band``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The temporary file only backs the sqlite database while the catalogs
    # are written and read back; the ``with`` block closes (and deletes) it
    # deterministically once everything has been loaded into DataFrames.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []

        apdb = ApdbSql(config=apdbConfig)
        apdb.makeSchema()

        wholeSky = Box.full()
        # Concatenate the inputs onto the results of querying the freshly
        # created schema so the stored catalogs carry the Apdb's columns.
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat([apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat([apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        diaObjects = apdb.getDiaObjects(wholeSky)
        # Compute the set of stored object ids once and reuse it for both
        # source queries.
        diaObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, diaObjectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(wholeSky, diaObjectIds, dateTime)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask` and its cutout/alert helpers."""

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.info.setVisitInfo(visit)

        self.exposure.setFilter(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = utils_tests.makeDiaObjects(2, self.exposure)
        diaSourceHistory = utils_tests.makeDiaSources(10,
                                                      diaObjects["diaObjectId"],
                                                      self.exposure)
        diaForcedSources = utils_tests.makeDiaForcedSources(10,
                                                            diaObjects["diaObjectId"],
                                                            self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.visitInfo.date)
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # (diaObjectId, band, diaSourceId) index labels of the sources that
        # are packaged as alerts; the rest form the source history.
        alertSourceLabels = [(1, "g", 9), (2, "g", 10)]
        # Take an explicit copy so adding ``bboxSize`` below writes to an
        # independent frame rather than a possible view of the history.
        self.diaSources = diaSourceHistory.loc[alertSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=alertSourceLabels)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A requested size below the minimum is clamped up to the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A sky position off the exposure yields no cutout.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is (diaObjectId, band, diaSourceId); see setUp.
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register cleanup immediately so the directory is removed even if an
        # assertion below fails (a trailing rmtree would be skipped then).
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    elif expected == 0:
                        # The relative check below would divide by zero.
                        self.assertAlmostEqual(value, 0.)
                    else:
                        self.assertAlmostEqual(
                            1 - value / expected,
                            0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["dec"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the `lsst.utils.tests.MemoryTestCase` checks for this module.

    No additional tests are defined here; everything is inherited.
    """
def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Executed directly (not via pytest): initialize the LSST test
    # utilities, then hand over to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()