Coverage for tests/test_packageAlerts.py: 17%
193 statements
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-12 10:45 +0000
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-12 10:45 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import io
23import os
24import numpy as np
25import pandas as pd
26import shutil
27import tempfile
28import unittest
30from astropy import wcs
31from astropy.nddata import CCDData
33from lsst.ap.association import (PackageAlertsConfig,
34 PackageAlertsTask,
35 make_dia_source_schema,
36 make_dia_object_schema)
37from lsst.afw.cameraGeom.testUtils import DetectorWrapper
38import lsst.afw.image as afwImage
39import lsst.daf.base as dafBase
40from lsst.dax.apdb import Apdb, ApdbConfig
41import lsst.geom as geom
42import lsst.meas.base.tests
43from lsst.utils import getPackageDir
44import lsst.utils.tests
def _data_file_name(basename, module_name):
    """Build the path of a file inside a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        Name of the file to append to the path.
    module_name : `str`
        Name of the LSST stack package, as understood by
        `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        ``basename`` joined onto the "data" directory of the package
        directory returned by `lsst.utils.getPackageDir`.
    """
    package_root = getPackageDir(module_name)
    return os.path.join(package_root, "data", basename)
def makeDiaObjects(nObjects, exposure):
    """Generate a catalog of test DiaObjects scattered over an exposure.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the bounding box, WCS, and visit date.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object, with ``diaObjectId`` running from 0 to
        ``nObjects - 1`` and positions drawn uniformly over the exposure
        bounding box.
    """
    bounds = geom.Box2D(exposure.getBBox())
    # x is drawn before y so a seeded generator is consumed in a fixed order.
    xPositions = np.random.uniform(bounds.getMinX(), bounds.getMaxX(), size=nObjects)
    yPositions = np.random.uniform(bounds.getMinY(), bounds.getMaxY(), size=nObjects)

    visitDate = exposure.getInfo().getVisitInfo().getDate()
    epochMJD = visitDate.get(system=dafBase.DateTime.MJD)

    skyWcs = exposure.getWcs()

    # Every object gets the same placeholder spatial index.
    pixelId = 1
    # Per-band flux counters, appended after the fixed columns of each row.
    bandCounters = {}
    for band in ["u", "g", "r", "i", "z", "y"]:
        bandCounters["%sPSFluxNdata" % band] = 0

    rows = []
    for objectId, (x, y) in enumerate(zip(xPositions, yPositions)):
        skyPosition = skyWcs.pixelToSky(x, y)
        row = {"ra": skyPosition.getRa().asDegrees(),
               "decl": skyPosition.getDec().asDegrees(),
               "radecTai": epochMJD,
               "diaObjectId": objectId,
               "pixelId": pixelId,
               "pmParallaxNdata": 0,
               "nearbyObj1": 0,
               "nearbyObj2": 0,
               "nearbyObj3": 0,
               "flags": 1,
               "nDiaSources": 5}
        row.update(bandCounters)
        rows.append(row)

    return pd.DataFrame(data=rows)
def makeDiaSources(nSources, diaObjectIds, exposure):
    """Generate a catalog of test DiaSources scattered over an exposure.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer ids of the DiaObjects to "associate" with the DiaSources;
        they are assigned to the sources cyclically.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the bounding box, WCS, visit info, and filter.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        One row per source, holding the minimum set of columns needed to
        build an alert.
    """
    bounds = geom.Box2D(exposure.getBBox())
    # x is drawn before y so a seeded generator is consumed in a fixed order.
    xPositions = np.random.uniform(bounds.getMinX(), bounds.getMaxX(), size=nSources)
    yPositions = np.random.uniform(bounds.getMinY(), bounds.getMaxY(), size=nSources)

    visitInfo = exposure.getInfo().getVisitInfo()
    epochMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)

    skyWcs = exposure.getWcs()
    visitId = exposure.getInfo().getVisitInfo().getExposureId()

    # Every source gets the same placeholder spatial index.
    pixelId = 1

    rows = []
    for sourceId, (x, y) in enumerate(zip(xPositions, yPositions)):
        skyPosition = skyWcs.pixelToSky(x, y)
        associatedId = diaObjectIds[sourceId % len(diaObjectIds)]
        # Only the minimum values required for the alert are filled in.
        row = {"ra": skyPosition.getRa().asDegrees(),
               "decl": skyPosition.getDec().asDegrees(),
               "x": x,
               "y": y,
               "ccdVisitId": visitId,
               "diaObjectId": associatedId,
               "ssObjectId": 0,
               "parentDiaSourceId": 0,
               "prv_procOrder": 0,
               "diaSourceId": sourceId,
               "pixelId": pixelId,
               # Each successive source is observed one day later.
               "midPointTai": epochMJD + 1.0 * sourceId,
               "filterName": exposure.getFilterLabel().bandLabel,
               "psNdata": 0,
               "trailNdata": 0,
               "dipNdata": 0,
               "flags": 1}
        rows.append(row)

    return pd.DataFrame(data=rows)
def makeDiaForcedSources(nSources, diaObjectIds, exposure):
    """Generate a catalog of test DiaForcedSources.

    Parameters
    ----------
    nSources : `int`
        Number of forced sources to create.
    diaObjectIds : `numpy.ndarray`
        Integer ids of the DiaObjects to "associate" with the forced
        sources; they are assigned cyclically.
    exposure : `lsst.afw.image.Exposure`
        Exposure supplying the visit info and filter.

    Returns
    -------
    diaForcedSources : `pandas.DataFrame`
        One row per forced source, holding the minimum set of columns
        needed to build an alert.
    """
    visitInfo = exposure.getInfo().getVisitInfo()
    epochMJD = visitInfo.getDate().get(system=dafBase.DateTime.MJD)

    baseVisitId = exposure.getInfo().getVisitInfo().getExposureId()

    # Only the minimum values required for the alert are filled in; both the
    # visit id and the observation time advance by one per forced source.
    rows = [
        {"diaForcedSourceId": forcedId,
         "ccdVisitId": baseVisitId + forcedId,
         "diaObjectId": diaObjectIds[forcedId % len(diaObjectIds)],
         "midPointTai": epochMJD + 1.0 * forcedId,
         "filterName": exposure.getFilterLabel().bandLabel,
         "flags": 0}
        for forcedId in range(nSources)
    ]

    return pd.DataFrame(data=rows)
def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    forcedSources : `pandas.DataFrame`
        Set of test DiaForcedSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId``.
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        ``(diaObjectId, filterName, diaSourceId)``.
    forcedSources : `pandas.DataFrame`
        Round tripped forced sources, indexed by ``diaObjectId``.
    """
    # The sqlite file is only needed while the Apdb is in use; the context
    # manager removes it deterministically, even if a call below raises.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        pixelRanges = [[minId, maxId + 1]]
        inputObjectIds = np.unique(objects["diaObjectId"])

        # Reading the (still empty) tables first yields frames that carry the
        # Apdb column layout; the test rows are then stacked underneath.
        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
        diaObjects = pd.concat(
            [apdb.getDiaObjects(pixelRanges, return_pandas=True), objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(inputObjectIds,
                                dateTime,
                                return_pandas=True),
             sources])
        diaForcedSources = pd.concat(
            [apdb.getDiaForcedSources(inputObjectIds,
                                      dateTime,
                                      return_pandas=True),
             forcedSources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaForcedSources(diaForcedSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read everything back so the returned frames have the Apdb schema.
        diaObjects = apdb.getDiaObjects(pixelRanges, return_pandas=True)
        storedObjectIds = np.unique(diaObjects["diaObjectId"])
        diaSources = apdb.getDiaSources(storedObjectIds,
                                        dateTime,
                                        return_pandas=True)
        diaForcedSources = apdb.getDiaForcedSources(storedObjectIds,
                                                    dateTime,
                                                    return_pandas=True)

    # drop=False keeps the index columns available as ordinary columns too.
    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)
    diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)

    return (diaObjects, diaSources, diaForcedSources)
class TestPackageAlerts(lsst.utils.tests.TestCase):
    """Tests of `PackageAlertsTask` against a small simulated exposure and
    DiaObject/DiaSource/DiaForcedSource catalogs round tripped through an
    Apdb.
    """

    def setUp(self):
        """Create the test exposure, the catalogs, and the reference cutout
        WCS shared by the tests.
        """
        # Fixed seed so the random catalog positions are reproducible.
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = lsst.geom.Point2D(50.1, 49.8)
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
                                    lsst.geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.exposure.setFilterLabel(afwImage.FilterLabel(band='g', physical="g.MP9401"))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        diaForcedSources = makeDiaForcedSources(10,
                                                diaObjects["diaObjectId"],
                                                self.exposure)
        self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            diaForcedSources,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        # Replace None entries with NaN in all three catalogs.
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        self.diaForcedSources.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # One source per object is treated as the "new" detection to alert
        # on; the remaining rows form the history. Keys are
        # (diaObjectId, filterName, diaSourceId).
        alertKeys = [(0, "g", 8), (1, "g", 9)]
        # Copy explicitly so adding a column below cannot touch (or warn
        # about touching) the frame the rows were selected from.
        self.diaSources = diaSourceHistory.loc[alertKeys, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=alertKeys)

        # Reference astropy WCS built from the exposure's sky origin and CD
        # matrix, with the reference pixel at the test source position.
        self.cutoutWcs = wcs.WCS(naxis=2)
        self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
        self.cutoutWcs.wcs.crval = [
            self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
            self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
        self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
        self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

    def testCreateExtent(self):
        """Test the extent creation for the cutout bbox.
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A request below the minimum must come back as the minimum.
        extent = packageAlerts.createDiaSourceExtent(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testCreateCcdDataCutout(self):
        """Test that the data is being extracted into the CCDData cutout
        correctly.
        """
        packageAlerts = PackageAlertsTask()

        diaSrcId = 1234
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            self.exposure.getWcs().getSkyOrigin(),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        calibExposure = self.exposure.getPhotoCalib().calibrateImage(
            self.exposure.getMaskedImage())

        self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
                                     self.cutoutWcs.wcs.cd)
        self.assertFloatsAlmostEqual(ccdData.data,
                                     calibExposure.getImage().array)

        # A cutout requested at (0, 0) degrees is expected to yield None.
        ccdData = packageAlerts.createCcdDataCutout(
            self.exposure,
            geom.SpherePoint(0, 0, geom.degrees),
            self.exposure.getBBox().getDimensions(),
            self.exposure.getPhotoCalib(),
            diaSrcId)
        self.assertIsNone(ccdData)

    def testMakeLocalTransformMatrix(self):
        """Test that the local WCS approximation is correct.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cd = packageAlerts.makeLocalTransformMatrix(
            cutout.getWcs(), self.center, sphPoint)
        self.assertFloatsAlmostEqual(
            cd,
            cutout.getWcs().getCdMatrix(),
            rtol=1e-11,
            atol=1e-11)

    def testStreamCcdDataToBytes(self):
        """Test round tripping an CCDData cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))
        cutoutCcdData = CCDData(
            data=cutout.getImage().array,
            wcs=self.cutoutWcs,
            unit="adu")

        cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
        with io.BytesIO(cutoutBytes) as bytesIO:
            cutoutFromBytes = CCDData.read(bytesIO, format="fits")
        self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            cutoutBytes = packageAlerts.streamCcdDataToBytes(
                ccdCutout)
            # srcIdx is (diaObjectId, filterName, diaSourceId); the first
            # element selects everything belonging to the same DiaObject.
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
            # The same cutout stands in for both difference and template.
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                objForcedSources,
                ccdCutout,
                ccdCutout)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"],
                             cutoutBytes)
            self.assertEqual(alert["cutoutTemplate"],
                             cutoutBytes)

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Registered immediately so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir, ignore_errors=True)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.diaForcedSources,
                          self.exposure,
                          self.exposure,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            writer_schema, data_stream = \
                packageAlerts.alertSchema.retrieve_alerts(f)
            # Materialise the alerts while the file is still open.
            data = list(data_stream)
        self.assertEqual(len(data), len(self.diaSources))
        for idx, alert in enumerate(data):
            for key, value in alert["diaSource"].items():
                expected = self.diaSources.iloc[idx][key]
                if isinstance(value, float):
                    if np.isnan(expected):
                        self.assertTrue(np.isnan(value))
                    else:
                        # Compare as a relative difference.
                        self.assertAlmostEqual(
                            1 - value / expected,
                            0.)
                else:
                    self.assertEqual(value, expected)
            sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                        alert["diaSource"]["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            ccdCutout = packageAlerts.createCcdDataCutout(
                cutout,
                sphPoint,
                geom.Extent2I(self.cutoutSize, self.cutoutSize),
                cutout.getPhotoCalib(),
                1234)
            self.assertEqual(alert["cutoutDifference"],
                             packageAlerts.streamCcdDataToBytes(ccdCutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull `lsst.utils.tests.MemoryTestCase` into this module unchanged;
    no tests are added or overridden here.
    """
def setup_module(module):
    """Module-level setup hook: initialise the LSST test utilities.

    Parameters
    ----------
    module : `types.ModuleType`
        Module being set up; unused. (Presumably supplied by the test
        runner following the ``setup_module`` naming convention.)
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: initialise the LSST test utilities, then hand over
    # to the unittest command-line runner.
    lsst.utils.tests.init()
    unittest.main()