Coverage for tests/test_packageAlerts.py : 18%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
34from lsst.afw.cameraGeom.testUtils import DetectorWrapper
35import lsst.afw.image as afwImage
36import lsst.daf.base as dafBase
37from lsst.dax.apdb import ApdbSql, ApdbSqlConfig
38import lsst.geom as geom
39import lsst.meas.base.tests
40from lsst.sphgeom import Box
41from lsst.utils import getPackageDir
42import lsst.utils.tests
def _data_file_name(basename, module_name):
    """Build the full path of a file in a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name to append to the package's data directory.
    module_name : `str`
        Name of the LSST stack package whose root directory is looked up
        with `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        Path of the form ``<package dir>/data/<basename>``.
    """
    packageRoot = getPackageDir(module_name)
    dataDir = os.path.join(packageRoot, "data")
    return os.path.join(dataDir, basename)
def makeDiaObjects(nObjects, exposure):
    """Generate fake DiaObjects at random positions over an exposure.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box and WCS determine the object positions
        and whose visit date fills ``radecTai``.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object; ``diaObjectId`` runs from 1 to ``nObjects``.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # All x draws happen before any y draw; changing that order would change
    # the positions produced for a given numpy random seed.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nObjects)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    midPointTaiMJD = visitDate.get(system=dafBase.DateTime.MJD)

    # Named so that it does not shadow the module-level ``astropy.wcs`` import.
    exposureWcs = exposure.getWcs()

    def _objectRow(objectId, x, y):
        """Assemble the column values of a single DiaObject."""
        skyPos = exposureWcs.pixelToSky(x, y)
        row = {
            "ra": skyPos.getRa().asDegrees(),
            "decl": skyPos.getDec().asDegrees(),
            "radecTai": midPointTaiMJD,
            "diaObjectId": objectId,
            "pmParallaxNdata": 0,
            "nearbyObj1": 0,
            "nearbyObj2": 0,
            "nearbyObj3": 0,
            "flags": 1,
            "nDiaSources": 5,
        }
        row.update({f"{band}PSFluxNdata": 0 for band in "ugrizy"})
        return row

    rows = [_objectRow(objectId, x, y)
            for objectId, (x, y) in enumerate(zip(xs, ys), start=1)]
    return pd.DataFrame(data=rows)
def makeDiaSources(nSources, diaObjectIds, exposure):
    """Generate fake DiaSources at random positions over an exposure.

    Parameters
    ----------
    nSources : `int`
        How many DiaSources to generate.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer ids of DiaObjects to "associate" with the sources; sources
        are assigned to them round-robin.
    exposure : `lsst.afw.image.Exposure`
        Exposure providing the bounding box, WCS, visit date, exposure id and
        band label used for every source.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        One row per source; ``diaSourceId`` runs from 1 to ``nSources``.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # All x draws happen before any y draw; changing that order would change
    # the positions produced for a given numpy random seed.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nSources)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)

    # Named so that it does not shadow the module-level ``astropy.wcs`` import.
    exposureWcs = exposure.getWcs()
    ccdVisitId = visitInfo.getExposureId()

    rows = []
    for idx, (x, y) in enumerate(zip(xs, ys)):
        skyPos = exposureWcs.pixelToSky(x, y)
        # Only the minimum set of values needed to build an alert.
        rows.append({
            "ra": skyPos.getRa().asDegrees(),
            "decl": skyPos.getDec().asDegrees(),
            "x": x,
            "y": y,
            "ccdVisitId": ccdVisitId,
            "diaObjectId": diaObjectIds[idx % len(diaObjectIds)],
            "ssObjectId": 0,
            "parentDiaSourceId": 0,
            "prv_procOrder": 0,
            "diaSourceId": idx + 1,
            "midPointTai": midPointTaiMJD + 1.0 * idx,
            "filterName": exposure.getFilterLabel().bandLabel,
            "psNdata": 0,
            "trailNdata": 0,
            "dipNdata": 0,
            "flags": 1,
        })

    return pd.DataFrame(data=rows)
def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    No positions are generated; each row carries only the identifiers, time,
    band and flags needed for an alert.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer Ids of diaobjects to "associate" with the DiaForcedSources;
        rows are assigned to them round-robin.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the visit date, exposure id and band label.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources with ``diaForcedSourceId`` running from 1 to
        ``nSources``.
    """
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()

    data = []
    for idx in range(nSources):
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Each row gets its own ccdVisitId and a midPointTai one day (MJD)
        # after the previous row -- presumably to mimic a history of forced
        # measurements taken on different visits.
        data.append({"diaForcedSourceId": idx + 1,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": exposure.getFilterLabel().bandLabel,
                     "flags": 0})

    return pd.DataFrame(data=data)
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    A throw-away SQLite Apdb is created in a temporary file, the inputs are
    stored in it, and everything is read back.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `lsst.daf.base.DateTime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The SQLite database lives at the temporary file's path; the context
    # manager closes and removes it deterministically once we are done.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbSqlConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = ApdbSql(config=apdbConfig)
        apdb.makeSchema()

        wholeSky = Box.full()
        # Concatenate the (freshly created, so presumably empty) Apdb query
        # results with the inputs so they carry the Apdb column set.
        # ``pd.concat`` replaces ``DataFrame.append``, which was deprecated
        # in pandas 1.4 and removed in 2.0.
        diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(wholeSky, [], dateTime), sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])

        apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

        diaObjects = apdb.getDiaObjects(wholeSky)
        objectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(wholeSky, objectIds, dateTime)
        diaForcedSources = apdb.getDiaForcedSources(
            wholeSky, objectIds, dateTime)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask` and its cutout/alert helpers."""

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.info.id = 1234
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilterLabel(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Sources 9 and 10 are the last source of objects 1 and 2; they are
        # the sources alerts get packaged for, the rest form the history.
        newSourceLabels = [(1, "g", 9), (2, "g", 10)]
        # Copy explicitly so adding ``bboxSize`` below edits an independent
        # frame, not a selection of diaSourceHistory (SettingWithCopy).
        self.diaSources = diaSourceHistory.loc[newSourceLabels, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceLabels)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the minimum should be raised to minCutoutSize.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibImage = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibImage.getImage().array)

        # A center at (0, 0) degrees yields no cutout (presumably because it
        # falls off this exposure).
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping a CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is the (diaObjectId, filterName, diaSourceId) index tuple.
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"], cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"], cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register removal right away so the directory is cleaned up even
        # when one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data_stream = packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                if isinstance(value, float):
                    if np.isnan(expectedSource[key]):
                        self.assertTrue(np.isnan(value))
                    else:
                        self.assertAlmostEqual(
                            1 - value / expectedSource[key],
                            0.)
                else:
                    self.assertEqual(value, expectedSource[key])
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Reuse the checks of `lsst.utils.tests.MemoryTestCase` unchanged.

    This subclass adds no tests of its own; it exists so the inherited
    checks are collected for this module.
    """
def setup_module(module):
    """Module-level setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module : `module`
        The test module being set up (unused).
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Executed directly rather than through a test runner: set up the LSST
    # test utilities first, then let unittest discover and run the tests.
    lsst.utils.tests.init()
    unittest.main()