Coverage for tests/test_packageAlerts.py : 17%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import (PackageAlertsConfig,
34 PackageAlertsTask,
35 make_dia_source_schema,
36 make_dia_object_schema)
37from lsst.afw.cameraGeom.testUtils import DetectorWrapper
38import lsst.afw.image as afwImage
39import lsst.afw.image.utils as afwImageUtils
40import lsst.daf.base as dafBase
41from lsst.dax.apdb import Apdb, ApdbConfig
42import lsst.geom as geom
43import lsst.meas.base.tests
44from lsst.utils import getPackageDir
45import lsst.utils.tests
def _data_file_name(basename, module_name):
    """Build the absolute path of a file in a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name to look up inside the ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose install directory is
        resolved with `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        The path ``<package dir>/data/<basename>``.
    """
    packageDataDir = os.path.join(getPackageDir(module_name), "data")
    return os.path.join(packageDataDir, basename)
def makeDiaObjects(nObjects, exposure):
    """Make a test set of DiaObjects.

    Positions are drawn uniformly over the exposure's bounding box using the
    global `numpy.random` state, so callers should seed it for reproducible
    catalogs.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create objects over.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        DiaObjects generated across the exposure, one row per object, with
        ``diaObjectId`` equal to the row number.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # Draw all x positions before all y positions; changing this order would
    # change the catalog produced for a given seed.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nObjects)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nObjects)

    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)
    # Named expWcs so it does not shadow the module-level astropy ``wcs``.
    expWcs = exposure.getWcs()

    # Loop invariants: every object shares the same constant pixel index and
    # the same zero per-band PSF-flux data counts.
    htmIdx = 1
    bandNdata = {f"{band}PSFluxNdata": 0
                 for band in ["u", "g", "r", "i", "z", "y"]}

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        newObject = {"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "radecTai": midPointTaiMJD,
                     "diaObjectId": idx,
                     "pixelId": htmIdx,
                     "pmParallaxNdata": 0,
                     "nearbyObj1": 0,
                     "nearbyObj2": 0,
                     "nearbyObj3": 0,
                     "flags": 1,
                     "nDiaSources": 5}
        # Appended after the fixed keys so the column order is unchanged.
        newObject.update(bandNdata)
        data.append(newObject)

    return pd.DataFrame(data=data)
def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Positions are drawn uniformly over the exposure's bounding box using the
    global `numpy.random` state, so callers should seed it for reproducible
    catalogs.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer Ids of diaobjects to "associate" with the DiaSources. Source
        ``i`` is assigned ``diaObjectIds[i % len(diaObjectIds)]``
        (round-robin).
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure. ``diaSourceId`` equals the
        row number, and ``midPointTai`` increases by one day per source,
        starting at the exposure's MJD.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # x positions are drawn before y positions; keep this order so a given
    # seed always yields the same catalog.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named expWcs so it does not shadow the module-level astropy ``wcs``.
    expWcs = exposure.getWcs()
    ccdVisitId = exposure.getInfo().getVisitInfo().getExposureId()
    # TODO DM-27170: fix this [0] workaround which gets a
    # single character representation of the band.
    filterName = exposure.getFilter().getCanonicalName()[0]
    # Every source shares the same constant pixel index.
    htmIdx = 1
    nObjectIds = len(diaObjectIds)

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        objId = diaObjectIds[idx % nObjectIds]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)
def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer Ids of diaobjects to "associate" with the DiaForcedSources.
        Forced source ``i`` is assigned ``diaObjectIds[i % len(diaObjectIds)]``
        (round-robin).
    exposure : `lsst.afw.image.Exposure`
        Exposure whose visit info and filter seed the generated rows.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources, one row per source. ``diaForcedSourceId`` equals
        the row number. ``ccdVisitId`` and ``midPointTai`` (in days) both
        increase by one per row, starting from the exposure's values.
    """
    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    ccdVisitId = exposure.getInfo().getVisitInfo().getExposureId()
    # TODO DM-27170: fix this [0] workaround which gets a
    # single character representation of the band.
    filterName = exposure.getFilter().getCanonicalName()[0]
    nObjectIds = len(diaObjectIds)

    data = []
    for idx in range(nSources):
        objId = diaObjectIds[idx % nObjectIds]
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "flags": 0})

    return pd.DataFrame(data=data)
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    The catalogs are written to a fresh sqlite-backed Apdb stored in a
    temporary file and then read back. The temporary file is removed before
    this function returns.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId`` (column kept).
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``) (columns kept).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``
        (column kept).
    """
    # Hold the temporary database file open explicitly for the lifetime of
    # the Apdb. Everything is read back into memory before the ``with``
    # block closes and deletes it.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        objectIds = np.unique(objects["diaObjectId"])

        # Prepend the Apdb's own query results (presumably empty, since the
        # schema was just created) to the test rows before storing.
        # pd.concat replaces DataFrame.append, which was removed in
        # pandas 2.0.
        diaObjects = pd.concat(
            [apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True),
             objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(objectIds, dateTime, return_pandas=True),
             sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(objectIds, dateTime,
                                      return_pandas=True),
             forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read everything back so the returned frames reflect what the
        # Apdb actually stored.
        diaObjects = apdb.getDiaObjects([[minId, maxId + 1]],
                                        return_pandas=True)
        storedIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedIds,
                                                    dateTime,
                                                    return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `lsst.ap.association.PackageAlertsTask`."""

    def setUp(self):
        """Build a simulated exposure and matching DiaObject, DiaSource and
        DiaForcedSource catalogs round-tripped through an sqlite Apdb.
        """
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = \
            _roundTripThroughApdb(
                diaObjects,
                diaSourceHistory,
                diaForcedSources,
                self.exposure.getInfo().getVisitInfo().getDate().toPython())
        # Normalize None entries to NaN so the comparisons in testRun can
        # use np.isnan.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan,
                                      inplace=True)
        diaSourceHistory["programId"] = 0

        # DiaSources 8 and 9 (assigned round-robin to objects 0 and 1) act
        # as the new sources to alert on; the remaining rows form their
        # history.
        newSourceIndices = [(0, "g", 8), (1, "g", 9)]
        # Take an explicit copy so adding a column below modifies an
        # independent frame rather than a possible view.
        self.diaSources = diaSourceHistory.loc[newSourceIndices, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceIndices)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateBBox(self):
        """Test the bbox creation
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # Requesting a size below the minimum must be clamped up to it.
        bbox = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5)
        self.assertEqual(bbox, geom.Extent2I(packConfig.minCutoutSize,
                                             packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        bbox = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(bbox, geom.Extent2I(self.cutoutSize,
                                             self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getPhotoCalib())
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is (diaObjectId, filterName, diaSourceId).
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register cleanup immediately so the directory is removed even if
        # run() or any assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _writerSchema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            expected = self.diaSources.iloc[idx]
            for key, value in alert["diaSource"].items():
                if isinstance(value, float):
                    if np.isnan(expected[key]):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Relative comparison of the serialized float.
                        self.assertAlmostEqual(
                            1 - value / expected[key],
                            0.)
                else:
                    self.assertEqual(value, expected[key])
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull in the `lsst.utils.tests.MemoryTestCase` checks for this module.

    All behavior is inherited; this subclass only registers the checks with
    the test runner.
    """
def setup_module(module):
    """Module-level setup hook (pytest xunit-style signature).

    Initializes the `lsst.utils.tests` machinery before any test in this
    module runs. The ``module`` argument is accepted for the hook signature
    and is not used.
    """
    del module  # unused; required by the hook signature
    lsst.utils.tests.init()
# Allow running this file directly as a script, outside pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()