Coverage for tests/test_packageAlerts.py : 18%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import numpy as np
24import pandas as pd
25import shutil
26import tempfile
27import unittest
29from lsst.ap.association import (PackageAlertsConfig,
30 PackageAlertsTask,
31 make_dia_source_schema,
32 make_dia_object_schema)
33from lsst.afw.cameraGeom.testUtils import DetectorWrapper
34import lsst.afw.fits as afwFits
35import lsst.afw.image as afwImage
36import lsst.afw.image.utils as afwImageUtils
37import lsst.daf.base as dafBase
38from lsst.dax.apdb import Apdb, ApdbConfig
39import lsst.geom as geom
40import lsst.meas.base.tests
41from lsst.utils import getPackageDir
42import lsst.utils.tests
def _data_file_name(basename, module_name):
    """Build the absolute path of a file in a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name, relative to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose location is looked up with
        `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        ``<package dir>/data/<basename>``.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)
def makeDiaObjects(nObjects, exposure):
    """Build a synthetic DiaObject catalog scattered over an exposure.

    Pixel positions are drawn uniformly inside the exposure's bounding box
    and converted to sky coordinates with the exposure's WCS.

    Parameters
    ----------
    nObjects : `int`
        Number of objects to create.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create objects over.

    Returns
    -------
    diaObjects : `pandas.DataFrame`
        One row per object; ``diaObjectId`` runs from 0 to ``nObjects - 1``.
    """
    pixelBox = geom.Box2D(exposure.getBBox())
    # All x positions are drawn before any y positions; keep this order so
    # the seeded random stream produces the same catalog.
    xs = np.random.uniform(pixelBox.getMinX(), pixelBox.getMaxX(),
                           size=nObjects)
    ys = np.random.uniform(pixelBox.getMinY(), pixelBox.getMaxY(),
                           size=nObjects)

    visitMjd = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)
    skyWcs = exposure.getWcs()
    bands = ("u", "g", "r", "i", "z", "y")

    def _makeRow(objectId, px, py):
        # Key insertion order determines the DataFrame column order.
        skyPos = skyWcs.pixelToSky(px, py)
        row = {"ra": skyPos.getRa().asDegrees(),
               "decl": skyPos.getDec().asDegrees(),
               "radecTai": visitMjd,
               "diaObjectId": objectId,
               # Fixed placeholder pixel index; no pixelization is computed.
               "pixelId": 1,
               "pmParallaxNdata": 0,
               "nearbyObj1": 0,
               "nearbyObj2": 0,
               "nearbyObj3": 0,
               "flags": 1,
               "nDiaSources": 5}
        row.update({f"{band}PSFluxNdata": 0 for band in bands})
        return row

    rows = [_makeRow(objectId, px, py)
            for objectId, (px, py) in enumerate(zip(xs, ys))]
    return pd.DataFrame(data=rows)
def makeDiaSources(nSources, diaObjectIds, exposure):
    """Make a test set of DiaSources.

    Pixel positions are drawn uniformly inside the exposure's bounding box.
    Sources are assigned to ``diaObjectIds`` round-robin: source ``i`` is
    associated with ``diaObjectIds[i % len(diaObjectIds)]``.

    Parameters
    ----------
    nSources : `int`
        Number of sources to create.
    diaObjectIds : `numpy.ndarray` or `pandas.Series`
        Integer Ids of diaobjects to "associate" with the DiaSources. Elements
        are looked up with ``diaObjectIds[i]``, so a Series should carry a
        default integer (Range) index.
    exposure : `lsst.afw.image.Exposure`
        Exposure to create sources over.

    Returns
    -------
    diaSources : `pandas.DataFrame`
        DiaSources generated across the exposure. ``diaSourceId`` runs from 0
        to ``nSources - 1`` and ``midPointTai`` is the visit MJD plus
        ``diaSourceId`` days, so later ids are more recent.
    """
    bbox = geom.Box2D(exposure.getBBox())
    rand_x = np.random.uniform(bbox.getMinX(), bbox.getMaxX(), size=nSources)
    rand_y = np.random.uniform(bbox.getMinY(), bbox.getMaxY(), size=nSources)

    midPointTaiMJD = exposure.getInfo().getVisitInfo().getDate().get(
        system=dafBase.DateTime.MJD)

    wcs = exposure.getWcs()
    ccdVisitId = exposure.getInfo().getVisitInfo().getExposureId()
    # Loop invariant: every source shares the exposure's filter.
    filterName = exposure.getFilter().getCanonicalName()
    # Fixed placeholder pixel index; no pixelization is computed here.
    htmIdx = 1

    data = []
    for idx, (x, y) in enumerate(zip(rand_x, rand_y)):
        coord = wcs.pixelToSky(x, y)
        objId = diaObjectIds[idx % len(diaObjectIds)]
        # Put together the minimum values for the alert. Key order sets the
        # DataFrame column order.
        data.append({"ra": coord.getRa().asDegrees(),
                     "decl": coord.getDec().asDegrees(),
                     "x": x,
                     "y": y,
                     "ccdVisitId": ccdVisitId,
                     "diaObjectId": objId,
                     "ssObjectId": 0,
                     "parentDiaSourceId": 0,
                     "prv_procOrder": 0,
                     "diaSourceId": idx,
                     "pixelId": htmIdx,
                     "midPointTai": midPointTaiMJD + 1.0 * idx,
                     "filterName": filterName,
                     "filterId": 0,
                     "psNdata": 0,
                     "trailNdata": 0,
                     "dipNdata": 0,
                     "flags": 1})

    return pd.DataFrame(data=data)
def _roundTripThroughApdb(objects, sources, dateTime):
    """Run object and source catalogs through the Apdb to get the correct
    table schemas.

    A throw-away sqlite-backed Apdb is created, the catalogs are stored in
    it, and they are read back so the returned frames carry the columns and
    dtypes the Apdb produces.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Set of test DiaObjects to round trip. Must have ``pixelId`` and
        ``diaObjectId`` columns.
    sources : `pandas.DataFrame`
        Set of test DiaSources to round trip.
    dateTime : `datetime.datetime`
        Time for the Apdb.

    Returns
    -------
    objects : `pandas.DataFrame`
        Round tripped objects, indexed by ``diaObjectId`` (column kept).
    sources : `pandas.DataFrame`
        Round tripped sources, indexed by
        ``(diaObjectId, filterName, diaSourceId)`` (columns kept).
    """
    # The sqlite file is only needed while the Apdb is in use; the context
    # manager closes and removes it deterministically instead of relying on
    # garbage collection of the NamedTemporaryFile object.
    with tempfile.NamedTemporaryFile() as tmpFile:
        apdbConfig = ApdbConfig()
        apdbConfig.db_url = "sqlite:///" + tmpFile.name
        apdbConfig.isolation_level = "READ_UNCOMMITTED"
        apdbConfig.dia_object_index = "baseline"
        apdbConfig.dia_object_columns = []
        apdbConfig.schema_file = _data_file_name(
            "apdb-schema.yaml", "dax_apdb")
        apdbConfig.column_map = _data_file_name(
            "apdb-ap-pipe-afw-map.yaml", "ap_association")
        apdbConfig.extra_schema_file = _data_file_name(
            "apdb-ap-pipe-schema-extra.yaml", "ap_association")

        apdb = Apdb(config=apdbConfig,
                    afw_schemas=dict(DiaObject=make_dia_object_schema(),
                                     DiaSource=make_dia_source_schema()))
        apdb.makeSchema()

        minId = objects["pixelId"].min()
        maxId = objects["pixelId"].max()
        # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
        # pd.concat is its documented replacement.
        diaObjects = pd.concat(
            [apdb.getDiaObjects([[minId, maxId + 1]], return_pandas=True),
             objects])
        diaSources = pd.concat(
            [apdb.getDiaSources(np.unique(objects["diaObjectId"]),
                                dateTime,
                                return_pandas=True),
             sources])

        apdb.storeDiaSources(diaSources)
        apdb.storeDiaObjects(diaObjects, dateTime)

        # Read everything back while the database file still exists.
        diaObjects = apdb.getDiaObjects([[minId, maxId + 1]],
                                        return_pandas=True)
        diaSources = apdb.getDiaSources(np.unique(diaObjects["diaObjectId"]),
                                        dateTime,
                                        return_pandas=True)

    diaObjects.set_index("diaObjectId", drop=False, inplace=True)
    diaSources.set_index(["diaObjectId", "filterName", "diaSourceId"],
                         drop=False,
                         inplace=True)

    return (diaObjects, diaSources)
class TestPackageAlerts(unittest.TestCase):
    """Tests for `lsst.ap.association.PackageAlertsTask`."""

    def setUp(self):
        np.random.seed(1234)
        self.cutoutSize = 35
        self.center = geom.Point2D(50.1, 49.8)
        self.bbox = geom.Box2I(geom.Point2I(-20, -30),
                               geom.Extent2I(140, 160))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(100000.0, self.center)
        exposure, _ = self.dataset.realize(
            10.0,
            self.dataset.makeMinimalSchema(),
            randomSeed=0)
        self.exposure = exposure
        detector = DetectorWrapper(id=23, bbox=exposure.getBBox()).detector
        self.exposure.setDetector(detector)

        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.getInfo().setVisitInfo(visit)

        self.filter_names = ["g"]
        afwImageUtils.resetFilters()
        afwImageUtils.defineFilter('g', lambdaEff=487, alias="g.MP9401")
        self.exposure.setFilter(afwImage.Filter('g'))

        diaObjects = makeDiaObjects(2, self.exposure)
        diaSourceHistory = makeDiaSources(10,
                                          diaObjects["diaObjectId"],
                                          self.exposure)
        self.diaObjects, diaSourceHistory = _roundTripThroughApdb(
            diaObjects,
            diaSourceHistory,
            self.exposure.getInfo().getVisitInfo().getDate().toPython())
        self.diaObjects.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory.replace(to_replace=[None], value=np.nan, inplace=True)
        diaSourceHistory["programId"] = 0

        # Sources 8 and 9 (the latest midPointTai values, belonging to
        # objects 0 and 1) act as the new detections; the remaining sources
        # are their history.
        newSourceKeys = [(0, "g", 8), (1, "g", 9)]
        # Take an explicit copy so adding a column below does not trigger
        # pandas' SettingWithCopyWarning on a .loc selection.
        self.diaSources = diaSourceHistory.loc[newSourceKeys, :].copy()
        self.diaSources["bboxSize"] = self.cutoutSize
        self.diaSourceHistory = diaSourceHistory.drop(labels=newSourceKeys)

    def testCreateBBox(self):
        """Test the bbox creation
        """
        packConfig = PackageAlertsConfig()
        # Just create a minimum less than the default cutout.
        packConfig.minCutoutSize = self.cutoutSize - 5
        packageAlerts = PackageAlertsTask(config=packConfig)
        # A size below the configured minimum is expected to come back at
        # the minimum.
        extent = packageAlerts.createDiaSourceBBox(
            packConfig.minCutoutSize - 5)
        self.assertEqual(extent, geom.Extent2I(packConfig.minCutoutSize,
                                               packConfig.minCutoutSize))
        # Test that the cutout size is correct.
        extent = packageAlerts.createDiaSourceBBox(self.cutoutSize)
        self.assertEqual(extent, geom.Extent2I(self.cutoutSize,
                                               self.cutoutSize))

    def testMakeCutoutBytes(self):
        """Test round tripping an exposure/cutout to bytes and back.
        """
        packageAlerts = PackageAlertsTask()

        sphPoint = self.exposure.getWcs().pixelToSky(self.center)
        cutout = self.exposure.getCutout(sphPoint,
                                         geom.Extent2I(self.cutoutSize,
                                                       self.cutoutSize))

        cutoutBytes = packageAlerts.makeCutoutBytes(cutout)
        tempMemFile = afwFits.MemFileManager(len(cutoutBytes))
        tempMemFile.setData(cutoutBytes, len(cutoutBytes))
        cutoutFromBytes = afwImage.ExposureF(tempMemFile)
        np.testing.assert_array_equal(cutout.getImage().array,
                                      cutoutFromBytes.getImage().array)

    def testMakeAlertDict(self):
        """Test stripping data from the various data products and into a
        dictionary "alert".
        """
        packageAlerts = PackageAlertsTask()
        alertId = 1234

        for srcIdx, diaSource in self.diaSources.iterrows():
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["decl"],
                                        geom.degrees)
            cutout = self.exposure.getCutout(sphPoint,
                                             geom.Extent2I(self.cutoutSize,
                                                           self.cutoutSize))
            cutoutBytes = packageAlerts.makeCutoutBytes(cutout)
            # srcIdx is (diaObjectId, filterName, diaSourceId).
            objSources = self.diaSourceHistory.loc[srcIdx[0]]
            alert = packageAlerts.makeAlertDict(
                alertId,
                diaSource,
                self.diaObjects.loc[srcIdx[0]],
                objSources,
                cutout,
                None)
            self.assertEqual(len(alert), 9)

            self.assertEqual(alert["alertId"], alertId)
            self.assertEqual(alert["diaSource"], diaSource.to_dict())
            self.assertEqual(alert["cutoutDifference"], cutoutBytes)
            self.assertIsNone(alert["cutoutTemplate"])

    def testRun(self):
        """Test the run method of package alerts.
        """
        packConfig = PackageAlertsConfig()
        tempdir = tempfile.mkdtemp(prefix='alerts')
        # Register removal right away so the directory is deleted even when
        # an assertion below fails.
        self.addCleanup(shutil.rmtree, tempdir)
        packConfig.alertWriteLocation = tempdir
        packageAlerts = PackageAlertsTask(config=packConfig)

        packageAlerts.run(self.diaSources,
                          self.diaObjects,
                          self.diaSourceHistory,
                          self.exposure,
                          None,
                          None)

        ccdVisitId = self.exposure.getInfo().getVisitInfo().getExposureId()
        with open(os.path.join(tempdir, f"{ccdVisitId}.avro"), 'rb') as f:
            _, data = packageAlerts.alertSchema.retrieve_alerts(f)
            self.assertEqual(len(data), len(self.diaSources))
            for idx, alert in enumerate(data):
                expectedSource = self.diaSources.iloc[idx]
                for key, value in alert["diaSource"].items():
                    expected = expectedSource[key]
                    if isinstance(value, float):
                        # Relative comparison with NaN == NaN. Unlike the
                        # ratio 1 - value / expected, this stays well defined
                        # when the expected value is 0.
                        np.testing.assert_allclose(value, expected,
                                                   rtol=1e-7, err_msg=key)
                    else:
                        self.assertEqual(value, expected)
                sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
                                            alert["diaSource"]["decl"],
                                            geom.degrees)
                cutout = self.exposure.getCutout(
                    sphPoint,
                    geom.Extent2I(self.cutoutSize, self.cutoutSize))
                self.assertEqual(alert["cutoutDifference"],
                                 packageAlerts.makeCutoutBytes(cutout))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull in the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up (not used).
    """
    lsst.utils.tests.init()
# Allow running this file directly; initialize lsst.utils.tests first, as
# setup_module does under pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()