Coverage for tests/test_packageAlerts.py : 17%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import (PackageAlertsConfig,
34 PackageAlertsTask,
35 make_dia_source_schema,
36 make_dia_object_schema)
37from lsst.afw.cameraGeom.testUtils import DetectorWrapper
38import lsst.afw.image as afwImage
39import lsst.daf.base as dafBase
40from lsst.dax.apdb import Apdb, ApdbConfig
41import lsst.geom as geom
42import lsst.meas.base.tests
43from lsst.utils import getPackageDir
44import lsst.utils.tests
def _data_file_name(basename, module_name):
    """Build the full path of a file in a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name to append to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose root directory is looked up
        with ``getPackageDir``.

    Returns
    -------
    data_file_path : `str`
        ``<package root>/data/<basename>``.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)
def makeDiaObjects(nObjects, exposure):
    """Generate ``nObjects`` randomly placed test DiaObjects on an exposure.

    Pixel positions are drawn uniformly over the exposure's bounding box and
    converted to sky coordinates with the exposure's WCS.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box, WCS and visit date define the objects.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per generated DiaObject; ``diaObjectId`` runs from 0 to
        ``nObjects - 1``.
    """
    box = geom.Box2D(exposure.getBBox())
    # All x positions are drawn before any y position, so the seeded random
    # stream is consumed in a fixed order.
    xs = np.random.uniform(box.getMinX(), box.getMaxX(), size=nObjects)
    ys = np.random.uniform(box.getMinY(), box.getMaxY(), size=nObjects)

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)
    # Named expWcs so the module-level ``wcs`` (astropy) import is not shadowed.
    expWcs = exposure.getWcs()

    rows = []
    for objectId, (xPix, yPix) in enumerate(zip(xs, ys)):
        skyPos = expWcs.pixelToSky(xPix, yPix)
        row = {"ra": skyPos.getRa().asDegrees(),
               "decl": skyPos.getDec().asDegrees(),
               "radecTai": visitMjd,
               "diaObjectId": objectId,
               # Every object gets the same fixed spatial index value.
               "pixelId": 1,
               "pmParallaxNdata": 0,
               "nearbyObj1": 0,
               "nearbyObj2": 0,
               "nearbyObj3": 0,
               "flags": 1,
               "nDiaSources": 5}
        # One zero-valued "<band>PSFluxNdata" column per band, in ugrizy order.
        row.update((f"{band}PSFluxNdata", 0) for band in "ugrizy")
        rows.append(row)

    return pd.DataFrame(data=rows)
def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Pixel positions are drawn uniformly over the exposure's bounding box and
    converted to sky coordinates with the exposure's WCS. DiaObject ids are
    assigned round-robin (by position) from ``diaObjectIds``, and source
    ``idx`` gets ``midPointTai`` equal to the visit date plus ``idx`` days.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : array-like of `int`
        Integer Ids of diaobjects to "associate" with the DiaSources, e.g. a
        `numpy.ndarray` or a `pandas.Series`. Always indexed positionally.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure, one row per source.
    """
    bbox = geom.Box2D(exposure.getBBox())
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    # Loop invariants, looked up once. The WCS local is not named ``wcs`` so
    # it does not shadow the module-level astropy ``wcs`` import.
    expWcs = exposure.getWcs()
    band = exposure.getFilterLabel().bandLabel
    # Convert to an array so ``[i]`` is positional: a pandas.Series would
    # otherwise be indexed by label.
    objectIds = np.asarray(diaObjectIds)
    # All sources share the same fixed spatial index value.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objectIds[idx % len(objectIds)],
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": band,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)
def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    DiaObject ids are assigned round-robin (by position) from
    ``diaObjectIds``. Forced source ``idx`` gets ``ccdVisitId`` equal to the
    exposure id plus ``idx`` and ``midPointTai`` equal to the visit date plus
    ``idx`` days.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : array-like of `int`
        Integer Ids of diaobjects to "associate" with the DiaForcedSources,
        e.g. a `numpy.ndarray` or a `pandas.Series`. Always indexed
        positionally.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the visit date, exposure id and filter band.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources, one row per forced source.
    """
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    # Loop invariant, looked up once.
    band = exposure.getFilterLabel().bandLabel
    # Convert to an array so ``[i]`` is positional: a pandas.Series would
    # otherwise be indexed by label.
    objectIds = np.asarray(diaObjectIds)

    data = []
    for idx in range(nSources):
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objectIds[idx % len(objectIds)],
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": band,
                     "flags": 0})

    return pd.DataFrame(data=data)
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    The catalogs are stored in a fresh sqlite-backed Apdb created in a
    temporary file, then read back out. The temporary file is removed before
    returning.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The context manager guarantees the sqlite file is closed and deleted
    # even if an Apdb call raises.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        inputObjectIds = np.unique(objects["diaObjectId"])
        # Prepend the (empty) Apdb query results so the stored frames carry
        # the Apdb column layout. DataFrame.append was removed in pandas 2.0;
        # pd.concat([a, b]) is its direct equivalent.
        diaObjects = pd.concat(
            [apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True),
             objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(inputObjectIds, dateTime, return_pandas=True),
             sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(inputObjectIds, dateTime,
                                      return_pandas=True),
             forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        diaObjects = apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True)
        storedObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedObjectIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedObjectIds,
                                                    dateTime,
                                                    return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask`.

    The fixture is a simulated exposure plus DiaObject, DiaSource and
    DiaForcedSource catalogs that have been round-tripped through a
    temporary Apdb.
    """

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        # The source catalog returned by realize() is not needed.
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilterLabel(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        # Normalize any None entries (presumably missing values coming back
        # from the Apdb) to NaN, so float checks can use np.isnan.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # With 2 objects assigned round-robin, sources 8 and 9 are the last
        # ones of objects 0 and 1. They act as the new detections; the rest
        # form the history.
        newSourceKeys = [(0, "g", 8), (1, "g", 9)]
        # .copy() so the column assignment below writes to an independent
        # frame instead of a possible view of diaSourceHistory. This avoids
        # pandas' SettingWithCopyWarning.
        self.diaSources = diaSourceHistory.loc[newSourceKeys, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceKeys)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateBBox(self):
        """Test the bbox creation
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request smaller than the configured minimum must yield the minimum.
        bbox = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5)
        self.assertEqual(bbox, geom.Extent2I(packConfig.minCutoutSize,
                                             packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        bbox = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(bbox, geom.Extent2I(self.cutoutSize,
                                             self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getPhotoCalib())
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # srcIdx is the (diaObjectId, filterName, diaSourceId) index key.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register the cleanup immediately, so the directory is removed even
        # when run() or an assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expectedSource = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                if isinstance(value, float):
                    if np.isnan(expectedSource[key]):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Compare floats by relative difference.
                        self.assertAlmostEqual(
                            1 - value / expectedSource[key],
                            0.)
                else:
                    self.assertEqual(value, expectedSource[key])
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the inherited `lsst.utils.tests.MemoryTestCase` checks for this
    module. It defines no tests of its own."""
def setup_module(module):
    """Module-level setup hook (named per the pytest ``setup_module``
    convention) that initializes the LSST test utilities.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up. It is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Executed directly (not imported by a test runner): initialize the
    # LSST test utilities, then hand control to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()