Coverage for tests/test_packageAlerts.py : 17%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import (PackageAlertsConfig,
34 PackageAlertsTask,
35 make_dia_source_schema,
36 make_dia_object_schema)
37from lsst.afw.cameraGeom.testUtils import DetectorWrapper
38import lsst.afw.image as afwImage
39import lsst.afw.image.utils as afwImageUtils
40import lsst.daf.base as dafBase
41from lsst.dax.apdb import Apdb, ApdbConfig
42import lsst.geom as geom
43import lsst.meas.base.tests
44from lsst.utils import getPackageDir
45import lsst.utils.tests
def _data_file_name(basename, module_name):
    """Build the path of a file inside a stack package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name, relative to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose root directory is resolved
        with `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        ``<package root>/data/<basename>``.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)
def makeDiaObjects(nObjects, exposure):
    """Generate a table of test DiaObjects scattered over an exposure.

    Pixel positions are drawn uniformly (via ``np.random``) inside the
    exposure's bounding box and converted to sky coordinates with the
    exposure's WCS. Object ids are the row numbers ``0 .. nObjects - 1``.

    Parameters
    ----------
    nObjects : `int`
        How many DiaObjects to generate.
    exposure : `lsst.afw.image.Exposure`
        Exposure whose bounding box, WCS and visit date seed the objects.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per generated DiaObject.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # All x draws happen before all y draws; keep this order so seeded runs
    # reproduce the same positions.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(), size=nObjects)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    epochMjd = visitDate.get(system=dafBase.DateTime.MJD)

    skyWcs = exposure.getWcs()

    # Per-band PSF flux counters are identical for every object; they are
    # appended after the scalar columns to keep the column order stable.
    bandCounters = {f"{band}PSFluxNdata": 0 for band in ["u", "g", "r", "i", "z", "y"]}

    rows = []
    for objectId, (px, py) in enumerate(zip(xs, ys)):
        skyPos = skyWcs.pixelToSky(px, py)
        row = {
            "ra": skyPos.getRa().asDegrees(),
            "decl": skyPos.getDec().asDegrees(),
            "radecTai": epochMjd,
            "diaObjectId": objectId,
            # Constant placeholder spatial index, not derived from position.
            "pixelId": 1,
            "pmParallaxNdata": 0,
            "nearbyObj1": 0,
            "nearbyObj2": 0,
            "nearbyObj3": 0,
            "flags": 1,
            "nDiaSources": 5,
        }
        row.update(bandCounters)
        rows.append(row)

    return pd.DataFrame(data=rows)
def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Pixel positions are drawn uniformly (via ``np.random``) over the
    exposure's bounding box, and sources are assigned to the ids in
    ``diaObjectIds`` in round-robin order.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : array-like of `int`
        Integer Ids of diaobjects to "associate" with the DiaSources, e.g. a
        `numpy.ndarray` or `pandas.Series`. Ids are taken by position, so a
        Series with a non-default index is handled correctly.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure; ``diaSourceId`` equals the
        row number and ``midPointTai`` increases by one day per source.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # x draws happen before y draws; keep this order so seeded runs stay
    # reproducible.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    skyWcs = exposure.getWcs()
    # Loop invariant: every source shares the exposure's filter.
    filterName = exposure.getFilter().getCanonicalName()
    # Positional access, independent of any pandas index labels.
    objectIds = np.asarray(diaObjectIds)
    # Constant placeholder spatial index, not derived from the position.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = skyWcs.pixelToSky(x, y)
        objId = objectIds[idx % len(objectIds)]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)
def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Forced sources are assigned to the ids in ``diaObjectIds`` in
    round-robin order. Each one gets its own ``ccdVisitId`` (offset from the
    exposure id by the row number) and a ``midPointTai`` one day later than
    the previous one.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : array-like of `int`
        Integer Ids of diaobjects to "associate" with the forced sources,
        e.g. a `numpy.ndarray` or `pandas.Series`. Ids are taken by position.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the visit date, exposure id and filter.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources generated for the exposure.
    """
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    # Loop invariant: every forced source shares the exposure's filter.
    filterName = exposure.getFilter().getCanonicalName()
    # Positional access, independent of any pandas index labels.
    objectIds = np.asarray(diaObjectIds)

    data = []
    for idx in range(nSources):
        objId = objectIds[idx % len(objectIds)]
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "flags": 0})

    return pd.DataFrame(data=data)
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    The catalogs are stored in a freshly created Apdb backed by a temporary
    sqlite file and then read back. The returned frames therefore carry
    whatever columns the Apdb produces.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    diaSources : `pandas.DataFrame`
        Round tripped sources, indexed by
        ``(diaObjectId, filterName, diaSourceId)``.
    diaForcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # Use the temporary file as a context manager so the database file is
    # removed deterministically rather than whenever it is garbage collected.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        inputObjectIds = np.unique(objects["diaObjectId"])

        # Prepend the (empty) frames the Apdb returns so the stored data gets
        # the Apdb's columns. DataFrame.append was removed in pandas 2.0;
        # pd.concat is its exact equivalent.
        diaObjects = pd.concat(
            [apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True),
             objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(inputObjectIds, dateTime, return_pandas=True),
             sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(inputObjectIds, dateTime,
                                      return_pandas=True),
             forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        diaObjects = apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True)
        storedObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedObjectIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedObjectIds,
                                                    dateTime,
                                                    return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `lsst.ap.association.PackageAlertsTask`."""

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        # Simulated exposure with a single source at ``self.center``.
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        # Replace None values in the round-tripped frames with NaN.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Sources 8 and 9 (one per DiaObject) are treated as the new
        # detections; the remaining sources form the history.
        newSourceIndex = [(0, "g", 8), (1, "g", 9)]
        # Explicit copy so adding a column below does not raise pandas'
        # SettingWithCopyWarning.
        self.diaSources = diaSourceHistory.loc[newSourceIndex, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceIndex)

        # Astropy WCS built from the exposure's WCS, with its reference
        # pixel placed at ``self.center``.
        expWcs = self.exposure.getWcs()
        skyOrigin = expWcs.getSkyOrigin()
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [skyOrigin.getRa().asDegrees(),
                                    skyOrigin.getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = expWcs.getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateBBox(self):
        """Test the bbox creation
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request smaller than the minimum must come back at the minimum.
        bbox = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5)
        self.assertEqual(bbox, geom.Extent2I(packConfig.minCutoutSize,
                                             packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        bbox = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(bbox, geom.Extent2I(self.cutoutSize,
                                             self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getPhotoCalib())
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is (diaObjectId, filterName, diaSourceId); element 0
            # selects the matching object and its history.
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                None)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             None)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Registered immediately so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          None,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _writerSchema, data = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                expected = self.diaSources.iloc[idx]
                for key, value in alert["diaSource"].items():
                    if not isinstance(value, float):
                        self.assertEqual(value, expected[key])
                    elif np.isnan(expected[key]):
                        self.assertTrue(np.isnan(value))
                    elif expected[key] == 0:
                        # The relative comparison below would divide by 0.
                        self.assertEqual(value, 0.)
                    else:
                        self.assertAlmostEqual(1 - value / expected[key], 0.)
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["decl"],
                                            geom.degrees)
                cutout = self.exposure.getCutout(sphPoint,
                                                 geom.Extent2I(self.cutoutSize,
                                                               self.cutoutSize))
                ccdCutout = packageAlerts.createCcdDataCutout(
                    cutout, sphPoint, cutout.getPhotoCalib())
                self.assertEqual(alert["cutoutDifference"],
                                 packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` convention).

    Calls `lsst.utils.tests.init`; the ``module`` argument is unused.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running the file directly: initialize the LSST test utilities, then
    # hand control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()