Coverage for tests/test_packageAlerts.py: 17%
183 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-05 04:23 -0700
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
34from lsst.afw.cameraGeom.testUtils import DetectorWrapper
35import lsst.afw.image as afwImage
36import lsst.daf.base as dafBase
37from lsst.dax.apdb import ApdbSql, ApdbSqlConfig
38import lsst.geom as geom
39import lsst.meas.base.tests
40from lsst.sphgeom import Box
41import lsst.utils.tests
def makeDiaObjects(nObjects, exposure):
    """Make a test set of DiaObjects.

    Pixel positions are drawn uniformly over the exposure's bounding box
    (x first, then y, from the global numpy RNG) and converted to sky
    coordinates with the exposure's WCS.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create objects over.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        DiaObjects generated across the exposure, with ``diaObjectId``
        running from 1 to ``nObjects``.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    xPositions = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(),
                                   size=nObjects)
    yPositions = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(),
                                   size=nObjects)

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named to avoid shadowing the module-level astropy ``wcs`` import.
    exposureWcs = exposure.getWcs()
    # Per-band ``<band>PSFluxNdata`` columns, appended after the common ones.
    perBandColumns = {f"{band}PSFluxNdata": 0 for band in "ugrizy"}

    rows = []
    for objectId, (xPix, yPix) in enumerate(zip(xPositions, yPositions),
                                            start=1):
        skyPosition = exposureWcs.pixelToSky(xPix, yPix)
        row = {
            "ra": skyPosition.getRa().asDegrees(),
            "decl": skyPosition.getDec().asDegrees(),
            "radecTai": visitMjd,
            "diaObjectId": objectId,
            "pmParallaxNdata": 0,
            "nearbyObj1": 0,
            "nearbyObj2": 0,
            "nearbyObj3": 0,
            "flags": 1,
            "nDiaSources": 5,
        }
        row.update(perBandColumns)
        rows.append(row)

    return pd.DataFrame(data=rows)
def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Sources are placed at uniformly random pixel positions over the
    exposure (x draws first, then y, from the global numpy RNG) and assigned
    to the given object ids round-robin.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaSources.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure, with ``diaSourceId``
        running from 1 to ``nSources`` and ``midPointTai`` advancing by one
        day per source.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    xPositions = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(),
                                   size=nSources)
    yPositions = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(),
                                   size=nSources)

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named to avoid shadowing the module-level astropy ``wcs`` import.
    exposureWcs = exposure.getWcs()
    visitId = exposure.info.id
    band = exposure.getFilter().bandLabel
    nObjectIds = len(diaObjectIds)

    rows = []
    for position, (xPix, yPix) in enumerate(zip(xPositions, yPositions)):
        skyPosition = exposureWcs.pixelToSky(xPix, yPix)
        # Minimum set of values needed to build an alert.
        rows.append({
            "ra": skyPosition.getRa().asDegrees(),
            "decl": skyPosition.getDec().asDegrees(),
            "x": xPix,
            "y": yPix,
            "ccdVisitId": visitId,
            "diaObjectId": diaObjectIds[position % nObjectIds],
            "ssObjectId": 0,
            "parentDiaSourceId": 0,
            "prv_procOrder": 0,
            "diaSourceId": position + 1,
            "midPointTai": visitMjd + 1.0 * position,
            "filterName": band,
            "psNdata": 0,
            "trailNdata": 0,
            "dipNdata": 0,
            "flags": 1,
        })

    return pd.DataFrame(data=rows)
def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer Ids of diaobjects to "associate" with the DiaForcedSources;
        they are assigned round-robin.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose visit date, id and filter seed the forced sources.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources with ``diaForcedSourceId`` running from 1 to
        ``nSources`` and ``midPointTai`` advancing by one day per source.
    """
    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    ccdVisitId = exposure.info.id
    band = exposure.getFilter().bandLabel

    data = []
    for idx in range(nSources):
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx + 1,
                     # Offset per source so each forced source has its own
                     # ccdVisitId (presumably mimicking distinct visits).
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": band,
                     "flags": 0})

    return pd.DataFrame(data=data)
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The context manager closes (and deletes) the temporary sqlite file once
    # the round trip is finished instead of leaving it to garbage collection.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []

        apdb = ApdbSql(config=apdbConfig)
        apdb.makeSchema()

        wholeSky = Box.full()
        # Concatenate onto the catalogs read back from the freshly created
        # schema (presumably empty) so the inputs pick up the Apdb's columns.
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        diaObjects = apdb.getDiaObjects(wholeSky)
        objectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, objectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, objectIds, dateTime)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask` cutout handling and alert packaging."""

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilter(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate())
        # Normalize None values from the Apdb round trip to NaN.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Sources are assigned to the two objects round-robin, so diaSourceIds
        # 9 and 10 are the last source of objects 1 and 2; they become the
        # "new" sources and the rest remain as history.
        newSourceLabels = [(1, "g", 9), (2, "g", 10)]
        # Explicit copy so the column assignment below modifies an
        # independent frame rather than a selection of diaSourceHistory.
        self.diaSources = diaSourceHistory.loc[newSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceLabels)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # Requests below the minimum are expected to be raised to it.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A position off the exposure should yield no cutout.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # srcIdx[0] is the diaObjectId level of the source's MultiIndex.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Registered immediately so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.info.id
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                expected = expectedSource[key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    elif expected == 0:
                        # A relative comparison would divide by zero.
                        self.assertAlmostEqual(value, 0.)
                    else:
                        self.assertAlmostEqual(1 - value / expected, 0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Include the `lsst.utils.tests.MemoryTestCase` checks in this module.

    All behavior is inherited; no additional tests are defined here.
    """
def setup_module(module):
    """Initialize `lsst.utils.tests` before this module's tests run.

    ``module`` is the argument pytest passes to module-level setup hooks;
    it is not used.
    """
    lsst.utils.tests.init()
# Allow running this file directly as well as through pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()