Coverage for tests/test_packageAlerts.py : 17%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import (PackageAlertsConfig,
34 PackageAlertsTask,
35 make_dia_source_schema,
36 make_dia_object_schema)
37from lsst.afw.cameraGeom.testUtils import DetectorWrapper
38import lsst.afw.image as afwImage
39import lsst.afw.image.utils as afwImageUtils
40import lsst.daf.base as dafBase
41from lsst.dax.apdb import Apdb, ApdbConfig
42import lsst.geom as geom
43import lsst.meas.base.tests
44from lsst.utils import getPackageDir
45import lsst.utils.tests
def _data_file_name(basename, module_name):
    """Build the full path of a file in a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        Name of the file inside the ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose directory is looked up with
        `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        Path of the form ``<package dir>/data/<basename>``.
    """
    dataDir = os.path.join(getPackageDir(module_name), "data")
    return os.path.join(dataDir, basename)
def makeDiaObjects(nObjects, exposure):
    """Make a test set of DiaObjects.

    Positions are drawn uniformly over the exposure bounding box with
    `numpy.random`, so seed the generator for reproducible output.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create objects over.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        DiaObjects generated across the exposure, with ``diaObjectId`` running
        from 0 to ``nObjects - 1``.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # Draw x before y so seeded runs keep producing the same positions.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nObjects)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nObjects)

    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    # Named ``expWcs`` so it does not shadow the module-level ``astropy.wcs``
    # import used elsewhere in this file.
    expWcs = exposure.getWcs()

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        # Every test object uses the same placeholder spatial index value.
        htmIdx = 1
        newObject = {"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "radecTai": midPointTaiMJD,
                     "diaObjectId": idx,
                     "pixelId": htmIdx,
                     "pmParallaxNdata": 0,
                     "nearbyObj1": 0,
                     "nearbyObj2": 0,
                     "nearbyObj3": 0,
                     "flags": 1,
                     "nDiaSources": 5}
        for f in ["u", "g", "r", "i", "z", "y"]:
            newObject[f"{f}PSFluxNdata"] = 0
        data.append(newObject)

    return pd.DataFrame(data=data)
def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Positions are drawn uniformly over the exposure bounding box with
    `numpy.random`, so seed the generator for reproducible output.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer Ids of diaobjects to "associate" with the DiaSources. They are
        assigned round-robin: source ``i`` gets
        ``diaObjectIds[i % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure. Source ``i`` has
        ``diaSourceId == i`` and ``midPointTai`` equal to the visit MJD plus
        ``i`` days.
    """
    bbox = geom.Box2D(exposure.getBBox())
    # Draw x before y so seeded runs keep producing the same positions.
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    # Per-exposure values are computed once rather than on every iteration.
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    # Named ``expWcs`` so it does not shadow the module-level ``astropy.wcs``
    # import used elsewhere in this file.
    expWcs = exposure.getWcs()
    # TODO DM-21333: Remove [0] (first character only) workaround
    filterName = exposure.getFilter().getCanonicalName()[0]

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = expWcs.pixelToSky(x, y)
        # Every test source uses the same placeholder spatial index value.
        htmIdx = 1
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)
def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaForcedSources.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer Ids of diaobjects to "associate" with the DiaForcedSources.
        They are assigned round-robin: forced source ``i`` gets
        ``diaObjectIds[i % len(diaObjectIds)]``.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the visit time, exposure id and filter.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        DiaForcedSources for the given objects. Forced source ``i`` has
        ``diaForcedSourceId == i``, ``ccdVisitId`` offset by ``i`` from the
        exposure id, and ``midPointTai`` offset by ``i`` days from the visit
        MJD.
    """
    # Per-exposure values are computed once rather than on every iteration.
    visitInfo = exposure.getInfo().getVisitInfo()
    midPointTaiMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)
    ccdVisitId = visitInfo.getExposureId()
    # TODO DM-21333: Remove [0] (first character only) workaround
    filterName = exposure.getFilter().getCanonicalName()[0]

    data = []
    for idx in range(nSources):
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert.
        data.append({"diaForcedSourceId": idx,
                     "ccdVisitId": ccdVisitId + idx,
                     "diaObjectId": objId,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "flags": 0})

    return pd.DataFrame(data=data)
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    A throwaway SQLite-backed Apdb is created in a temporary file. The input
    catalogs are stored in it and then read back.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        (``diaObjectId``, ``filterName``, ``diaSourceId``).
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The context manager guarantees the temporary database file is closed
    # (and removed) once the round trip is finished.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        # Presumably a half-open pixelId range covering every input object;
        # verify against the Apdb.getDiaObjects documentation.
        pixelRanges = [[objects["pixelId"].min(), objects["pixelId"].max() + 1]]
        objectIds = np.unique(objects["diaObjectId"])

        # Query the (empty) tables first so the test data is concatenated
        # onto frames carrying the Apdb column layout. ``pd.concat`` replaces
        # ``DataFrame.append`` (deprecated in pandas 1.4, removed in 2.0);
        # its defaults match what ``append`` did.
        diaObjects = pd.concat([
            apdb.getDiaObjects(pixelRanges, return_pandas=True),
            objects])
        diaSources = pd.concat([
            apdb.getDiaSources(objectIds, dateTime, return_pandas=True),
            sources])
        diaForcedSources = pd.concat([
            apdb.getDiaForcedSources(objectIds, dateTime, return_pandas=True),
            forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        diaObjects = apdb.getDiaObjects(pixelRanges, return_pandas=True)
        storedObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedObjectIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedObjectIds,
                                                    dateTime,
                                                    return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests for `PackageAlertsTask`."""

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, catalog = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        # NOTE(review): presumably the Apdb hands back empty columns as None;
        # normalize them to NaN so float comparisons below work.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Sources alternate between the two objects, so diaSourceIds 8 and 9
        # are the last source of objects 0 and 1. They are used as the new
        # detections; the remaining sources become the history.
        newSourceIndex = [(0, "g", 8), (1, "g", 9)]
        # ``.copy()`` makes this an independent frame so adding a column does
        # not trigger pandas' SettingWithCopyWarning.
        self.diaSources = diaSourceHistory.loc[newSourceIndex, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceIndex)

        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateBBox(self):
        """Test the bbox creation
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # Requesting less than the configured minimum is expected to yield
        # the minimum size.
        bbox = packageAlerts.createDiaSourceBBox(packConfig.minCutoutSize - 5)
        self.assertEqual(bbox, geom.Extent2I(packConfig.minCutoutSize,
                                             packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        bbox = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(bbox, geom.Extent2I(self.cutoutSize,
                                             self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getPhotoCalib())
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            # srcIdx is (diaObjectId, filterName, diaSourceId).
            diaObjectId = srcIdx[0]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout, sphPoint, cutout.getPhotoCalib())
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            objSources = self.diaSourceHistory.loc[diaObjectId]
            objForcedSources = self.diaForcedSources.loc[diaObjectId]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[diaObjectId],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register removal right away so the directory is cleaned up even if
        # run() or one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                expectedSource = self.diaSources.iloc[idx]
                for key, value in alert["diaSource"].items():
                    expected = expectedSource[key]
                    if isinstance(value, float):
                        if np.isnan(expected):
                            self.assertTrue(np.isnan(value))
                        else:
                            # Compare the relative difference so the check
                            # does not depend on the magnitude of the value.
                            self.assertAlmostEqual(1 - value / expected, 0.)
                    else:
                        self.assertEqual(value, expected)
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["decl"],
                                            geom.degrees)
                cutout = self.exposure.getCutout(sphPoint,
                                                 geom.Extent2I(self.cutoutSize,
                                                               self.cutoutSize))
                ccdCutout = packageAlerts.createCcdDataCutout(
                    cutout, sphPoint, cutout.getPhotoCalib())
                self.assertEqual(alert["cutoutDifference"],
                                 packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No tests are added here; presumably the base class supplies the
    resource-leak checks — confirm against lsst.utils.
    """
def setup_module(module):
    """Module-level test hook: initialize `lsst.utils.tests` before any test
    in this module runs.

    The ``module`` argument is accepted for the hook signature but unused.
    """
    lsst.utils.tests.init()
# Allow running this file directly as a script: initialize the LSST test
# utilities, then hand control to unittest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()