Coverage for tests/test_association_task.py : 10%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import pandas as pd
24import unittest
26from lsst.afw.cameraGeom.testUtils import DetectorWrapper
27import lsst.afw.geom as afwGeom
28import lsst.afw.image as afwImage
29import lsst.afw.table as afwTable
30import lsst.daf.base as dafBase
31import lsst.geom as geom
32import lsst.sphgeom as sphgeom
33import lsst.utils.tests
35from lsst.ap.association import \
36 AssociationTask, \
37 make_dia_source_schema, \
38 make_dia_object_schema
def create_test_points(point_locs_deg,
                       wcs=None,
                       start_id=0,
                       schema=None,
                       scatter_arcsec=1.0,
                       indexer_ids=None,
                       associated_ids=None):
    """Create dummy DIASources or DIAObjects for use in our tests.

    Parameters
    ----------
    point_locs_deg : array-like (N, 2) of `float`
        Positions of the test points to create in RA, DEC (degrees).
    wcs : `lsst.afw.geom.SkyWcs`, optional
        If provided, used to compute pixel positions which are stored in
        the ``x`` and ``y`` fields of each record.
    start_id : `int`, optional
        Unique id of the first object to create. The remaining sources are
        incremented by one from the first id.
    schema : `lsst.afw.table.Schema`, optional
        Schema of the objects to create. Defaults to the DIASource schema.
    scatter_arcsec : `float`, optional
        If positive, each position is offset in a random direction by a
        random distance of up to ``scatter_arcsec`` arcseconds. Values
        <= 0 disable the scatter.
    indexer_ids : `list` of `int`, optional
        Id numbers of pixelization indexer to store in ``pixelId``. Must be
        the same length as the first dimension of point_locs_deg.
    associated_ids : `list` of `int`, optional
        Id numbers of associated DIAObjects to store in ``diaObjectId``.
        Must be the same length as the first dimension of point_locs_deg.

    Returns
    -------
    test_points : `lsst.afw.table.SourceCatalog`
        Catalog of points to test.
    """
    if schema is None:
        schema = make_dia_source_schema()
    sources = afwTable.SourceCatalog(schema)

    for src_idx, (ra, dec) in enumerate(point_locs_deg):
        src = sources.addNew()
        src['id'] = src_idx + start_id
        coord = geom.SpherePoint(ra, dec, geom.degrees)
        if scatter_arcsec > 0.0:
            # Random bearing in [0, 360) deg, random distance in
            # [0, scatter_arcsec) arcsec.
            coord = coord.offset(
                np.random.rand() * 360 * geom.degrees,
                np.random.rand() * scatter_arcsec * geom.arcseconds)
        if indexer_ids is not None:
            src['pixelId'] = indexer_ids[src_idx]
        if associated_ids is not None:
            src['diaObjectId'] = associated_ids[src_idx]
        src.setCoord(coord)

        if wcs is not None:
            # SkyWcs exposes ``skyToPixel``; the previous ``skykToPixel``
            # spelling raised AttributeError for any caller passing a wcs.
            xyCentroid = wcs.skyToPixel(coord)
            src.set("x", xyCentroid.getX())
            src.set("y", xyCentroid.getY())

    return sources
def create_test_points_pandas(point_locs_deg,
                              wcs=None,
                              start_id=0,
                              schema=None,
                              scatter_arcsec=1.0,
                              indexer_ids=None,
                              associated_ids=None):
    """Create dummy DIASources or DIAObjects as a `pandas.DataFrame`.

    This builds exactly the catalog produced by `create_test_points` with
    the same arguments and converts it via astropy to pandas, so the two
    helpers cannot drift apart.

    Parameters
    ----------
    point_locs_deg : array-like (N, 2) of `float`
        Positions of the test points to create in RA, DEC (degrees).
    wcs : `lsst.afw.geom.SkyWcs`, optional
        If provided, used to compute pixel positions which are stored in
        the ``x`` and ``y`` columns.
    start_id : `int`, optional
        Unique id of the first object to create. The remaining sources are
        incremented by one from the first id.
    schema : `lsst.afw.table.Schema`, optional
        Schema of the objects to create. Defaults to the DIASource schema.
    scatter_arcsec : `float`, optional
        If positive, each position is offset in a random direction by a
        random distance of up to ``scatter_arcsec`` arcseconds. Values
        <= 0 disable the scatter.
    indexer_ids : `list` of `int`, optional
        Id numbers of pixelization indexer to store. Must be the same length
        as the first dimension of point_locs_deg.
    associated_ids : `list` of `int`, optional
        Id numbers of associated DIAObjects to store. Must be the same length
        as the first dimension of point_locs_deg.

    Returns
    -------
    test_points : `pandas.DataFrame`
        Catalog of points to test.
    """
    sources = create_test_points(point_locs_deg,
                                 wcs=wcs,
                                 start_id=start_id,
                                 schema=schema,
                                 scatter_arcsec=scatter_arcsec,
                                 indexer_ids=indexer_ids,
                                 associated_ids=associated_ids)
    return sources.asAstropy().to_pandas()
class TestAssociationTask(unittest.TestCase):

    def setUp(self):
        """Build a test exposure (WCS, detector, visit info, filter and
        photometric calibration) and an HTM pixelization of its footprint.
        """
        self.filter_names = ["u", "g", "r", "i", "z"]
        self.dia_object_schema = make_dia_object_schema()

        # metadata taken from CFHT data
        # v695856-e0/v695856-e0-c000-a00.sci_img.fits

        self.metadata = dafBase.PropertySet()

        self.metadata.set("SIMPLE", "T")
        self.metadata.set("BITPIX", -32)
        self.metadata.set("NAXIS", 2)
        self.metadata.set("NAXIS1", 1024)
        self.metadata.set("NAXIS2", 1153)
        self.metadata.set("RADECSYS", 'FK5')
        self.metadata.set("EQUINOX", 2000.)

        self.metadata.setDouble("CRVAL1", 215.604025685476)
        self.metadata.setDouble("CRVAL2", 53.1595451514076)
        self.metadata.setDouble("CRPIX1", 1109.99981456774)
        self.metadata.setDouble("CRPIX2", 560.018167811613)
        self.metadata.set("CTYPE1", 'RA---SIN')
        self.metadata.set("CTYPE2", 'DEC--SIN')

        self.metadata.setDouble("CD1_1", 5.10808596133527E-05)
        self.metadata.setDouble("CD1_2", 1.85579539217196E-07)
        self.metadata.setDouble("CD2_2", -5.10281493481982E-05)
        self.metadata.setDouble("CD2_1", -8.27440751733828E-07)

        self.wcs = afwGeom.makeSkyWcs(self.metadata)
        # NOTE(review): if afw arrays are indexed [y, x], this image is the
        # transpose of NAXIS1 x NAXIS2 above; kept as-is so the fixture is
        # unchanged — confirm whether intended.
        self.exposure = afwImage.makeExposure(
            afwImage.makeMaskedImageFromArrays(np.ones((1024, 1153))),
            self.wcs)
        detector = DetectorWrapper(id=23, bbox=self.exposure.getBBox()).detector
        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.setDetector(detector)
        self.exposure.getInfo().setVisitInfo(visit)
        self.exposure.setFilterLabel(afwImage.FilterLabel(band='g'))
        self.flux0 = 10000
        self.flux0_err = 100
        self.exposure.setPhotoCalib(
            afwImage.PhotoCalib(self.flux0, self.flux0_err))

        bbox = geom.Box2D(self.exposure.getBBox())
        wcs = self.exposure.getWcs()

        self.pixelator = sphgeom.HtmPixelization(20)
        region = sphgeom.ConvexPolygon([wcs.pixelToSky(pp).getVector()
                                        for pp in bbox.getCorners()])

        # HTM index ranges (at most 64) covering the exposure footprint.
        # Not used by the tests in this class.
        indices = self.pixelator.envelope(region, 64)
        self.index_ranges = indices.ranges()

    def tearDown(self):
        """Release the metadata, WCS and exposure created in setUp.
        """
        del self.metadata
        del self.wcs
        del self.exposure

    def test_run(self):
        """Test the run method starting from existing DIAObjects and
        DIASources.
        """
        dia_objects = self._run_association_and_retrieve_objects(True)
        not_updated_idx = 0
        updated_idx_start = 1
        new_idx_start = 5
        total_expected_dia_objects = 10
        self.assertEqual(len(dia_objects), total_expected_dia_objects)

        # Test to make sure the DIAObjects have been properly associated.
        for obj_idx, (df_idx, dia_object) in enumerate(dia_objects.iterrows()):
            if df_idx == not_updated_idx:
                # Test the DIAObject we expect to not be associated with any
                # new DIASources.
                self.assertEqual(dia_object['gPSFluxNdata'], 1)
                self.assertEqual(dia_object['rPSFluxNdata'], 1)
                self.assertEqual(dia_object['nDiaSources'], 2)
                self.assertEqual(df_idx, obj_idx)
            elif updated_idx_start <= df_idx < new_idx_start:
                # Test that associating to the existing DIAObjects went
                # as planned and test that the IDs of the newly associated
                # DIASources is correct.
                self.assertEqual(dia_object['gPSFluxNdata'], 2)
                self.assertEqual(dia_object['rPSFluxNdata'], 1)
                self.assertEqual(dia_object['nDiaSources'], 3)
                self.assertEqual(df_idx, obj_idx)
            else:
                # New DIAObjects sit at positions 5-9 of the output and are
                # expected to carry the ids 14-18.
                self.assertEqual(dia_object['gPSFluxNdata'], 1)
                self.assertEqual(dia_object['nDiaSources'], 1)
                self.assertEqual(df_idx, obj_idx + 4 + 5)

    def test_run_no_existing_objects(self):
        """Test the run method with no pre-existing DIAObjects or
        DIASources.
        """
        dia_objects = self._run_association_and_retrieve_objects(False)
        total_expected_dia_objects = 9
        self.assertEqual(len(dia_objects),
                         total_expected_dia_objects)
        for obj_idx, (df_idx, output_dia_object) in enumerate(dia_objects.iterrows()):
            self.assertEqual(output_dia_object['gPSFluxNdata'], 1)
            self.assertEqual(df_idx, obj_idx + 10)

    def _run_association_and_retrieve_objects(self, create_objects=False):
        """Convenience method for testing the Association run method.

        Parameters
        ----------
        create_objects : `bool`
            If True, seed the association with the DIAObjects and DIASource
            history from `_create_dia_objects_and_sources`; otherwise start
            from empty catalogs.

        Returns
        -------
        dia_objects : `pandas.DataFrame`
            The ``diaObjects`` attribute of the `AssociationTask.run`
            result, i.e. the final set of DIAObjects to be tested.
        """
        if create_objects:
            diaObjects, diaSourceHistory = \
                self._create_dia_objects_and_sources()
        else:
            diaObjects = pd.DataFrame(columns=["diaObjectId"])
            diaSourceHistory = pd.DataFrame(columns=["diaObjectId",
                                                     "filterName",
                                                     "diaSourceId"])
        # Index both catalogs by their id columns while also keeping those
        # columns as regular columns (drop=False).
        diaObjects.set_index("diaObjectId",
                             inplace=True,
                             drop=False)
        diaSourceHistory.set_index(["diaObjectId",
                                    "filterName",
                                    "diaSourceId"],
                                   inplace=True,
                                   drop=False)

        # Nine new DIASources; the first four coincide with DIAObjects 1-4
        # created by _create_dia_objects_and_sources.
        source_centers = [
            [self.wcs.pixelToSky(idx, idx).getRa().asDegrees(),
             self.wcs.pixelToSky(idx, idx).getDec().asDegrees()]
            for idx in np.linspace(1, 1000, 10)[1:]]
        dia_sources = create_test_points(
            point_locs_deg=source_centers,
            start_id=10,
            scatter_arcsec=-1)
        visit_info = self.exposure.getInfo().getVisitInfo()
        for dia_source in dia_sources:
            self._set_source_values(
                dia_source=dia_source,
                flux=10000,
                fluxErr=100,
                filterName=self.exposure.getFilterLabel().bandLabel,
                ccdVisitId=visit_info.getExposureId(),
                midPointTai=visit_info.getDate().get(system=dafBase.DateTime.MJD))

        assoc_task = AssociationTask()

        diaSources = dia_sources.asAstropy().to_pandas()
        diaSources.rename(columns={"coord_ra": "ra",
                                   "coord_dec": "decl",
                                   "id": "diaSourceId",
                                   "parent": "parentDiaSourceId"},
                          inplace=True)
        # Convert the positions from radians to degrees.
        diaSources["ra"] = np.degrees(diaSources["ra"])
        diaSources["decl"] = np.degrees(diaSources["decl"])

        results = assoc_task.run(diaSources,
                                 diaObjects,
                                 diaSourceHistory)
        return results.diaObjects

    def _set_source_values(self, dia_source, flux, fluxErr, filterName,
                           ccdVisitId, midPointTai):
        """Set fluxes and visit info for DiaSources.

        The psFlux, apFlux and totFlux fields all receive ``flux / flux0``
        with the same propagated error.

        Parameters
        ----------
        dia_source : `lsst.afw.table.SourceRecord`
            SourceRecord object to edit.
        flux : `float`
            Flux of DiaSource
        fluxErr : `float`
            Flux error of DiaSource
        filterName : `str`
            Name of filter for flux.
        ccdVisitId : `int`
            Integer id of this ccd/visit.
        midPointTai : `float`
            Time of observation
        """
        calib_flux = flux / self.flux0
        # Error of flux / flux0 propagated from both fluxErr and flux0_err.
        calib_flux_err = np.sqrt(
            (fluxErr / self.flux0) ** 2
            + (flux * self.flux0_err / self.flux0 ** 2) ** 2)

        dia_source['ccdVisitId'] = ccdVisitId
        dia_source["midPointTai"] = midPointTai
        for flux_name in ("psFlux", "apFlux", "totFlux"):
            dia_source[flux_name] = calib_flux
            dia_source[flux_name + "Err"] = calib_flux_err
        dia_source["filterName"] = filterName
        dia_source["x"] = 0.
        dia_source["y"] = 0.

    def _create_dia_objects_and_sources(self):
        """Create a set of test DIAObjects and their DIASource history.

        Returns
        -------
        dia_objects : `pandas.DataFrame`
            Five DIAObjects (ids 0-4), each with ``nDiaSources`` = 2.
        dia_sources : `pandas.DataFrame`
            Ten DIASources, two per DIAObject: one in g and one in r.
        """
        # This should create 5 DIAObjects with 2 DIASources associated
        # to them. The DIASources are "observed" in g and r.

        # Create DIAObjects and give them fluxes.
        n_objects = 5
        object_centers = np.array([
            [self.wcs.pixelToSky(idx, idx).getRa().asDegrees(),
             self.wcs.pixelToSky(idx, idx).getDec().asDegrees()]
            for idx in np.linspace(1, 1000, 10)])
        dia_objects = create_test_points(
            point_locs_deg=object_centers[:n_objects],
            start_id=0,
            schema=self.dia_object_schema,
            scatter_arcsec=-1,)
        # Set the DIAObject fluxes and number of associated sources.
        for dia_object in dia_objects:
            dia_object["nDiaSources"] = 2
            # The HTM index depends only on position, so compute it once.
            sphPoint = geom.SpherePoint(dia_object.getCoord())
            dia_object["pixelId"] = self.pixelator.index(sphPoint.getVector())
            for filter_name in self.filter_names:
                dia_object['%sPSFluxMean' % filter_name] = 1
                dia_object['%sPSFluxMeanErr' % filter_name] = 1
                dia_object['%sPSFluxSigma' % filter_name] = 1
                dia_object['%sPSFluxNdata' % filter_name] = 1
        dia_objects = dia_objects.asAstropy().to_pandas()
        dia_objects.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaObjectId"},
                           inplace=True)
        dia_objects["ra"] = np.degrees(dia_objects["ra"])
        dia_objects["decl"] = np.degrees(dia_objects["decl"])

        dateTime = dafBase.DateTime("2014-05-13T16:00:00.000000000",
                                    dafBase.DateTime.Timescale.TAI)

        # Create DIASources and set their ccdVisitId and fluxes: the first
        # n_objects are g-band, the second n_objects are r-band.
        dia_sources = create_test_points(
            point_locs_deg=np.concatenate(
                [object_centers[:n_objects], object_centers[:n_objects]]),
            start_id=0,
            scatter_arcsec=-1,
            associated_ids=[0, 1, 2, 3, 4,
                            0, 1, 2, 3, 4])
        for src_idx, dia_source in enumerate(dia_sources):
            if src_idx < n_objects:
                self._set_source_values(
                    dia_source=dia_source,
                    flux=10000,
                    fluxErr=100,
                    filterName='g',
                    ccdVisitId=1232,
                    midPointTai=dateTime.get(system=dafBase.DateTime.MJD))
            else:
                self._set_source_values(
                    dia_source=dia_source,
                    flux=10000,
                    fluxErr=100,
                    filterName='r',
                    ccdVisitId=1233,
                    midPointTai=dateTime.get(system=dafBase.DateTime.MJD))
        dia_sources = dia_sources.asAstropy().to_pandas()
        dia_sources.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaSourceId",
                                    "parent": "parentDiaSourceId"},
                           inplace=True)
        dia_sources["ra"] = np.degrees(dia_sources["ra"])
        dia_sources["decl"] = np.degrees(dia_sources["decl"])
        return dia_objects, dia_sources

    def test_associate_sources(self):
        """Test the performance of the associate_sources method in
        AssociationTask.
        """
        n_objects = 5
        dia_objects = create_test_points_pandas(
            point_locs_deg=[[0.04 * obj_idx, 0.04 * obj_idx]
                            for obj_idx in range(n_objects)],
            start_id=0,
            schema=self.dia_object_schema,
            scatter_arcsec=-1,)
        dia_objects.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaObjectId"},
                           inplace=True)

        # Sources are shifted by one grid step relative to the objects, so
        # the last source has no counterpart.
        n_sources = 5
        dia_sources = create_test_points_pandas(
            point_locs_deg=[
                [0.04 * (src_idx + 1),
                 0.04 * (src_idx + 1)]
                for src_idx in range(n_sources)],
            start_id=n_objects,
            scatter_arcsec=0.1)
        dia_sources.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaSourceId"},
                           inplace=True)

        assoc_task = AssociationTask()
        assoc_result = assoc_task.associate_sources(
            dia_objects, dia_sources)

        for test_obj_id, expected_obj_id in zip(
                assoc_result.associated_dia_object_ids,
                [1, 2, 3, 4, 9]):
            self.assertEqual(test_obj_id, expected_obj_id)

    def test_score_and_match(self):
        """Test association between a set of sources and an existing
        DIAObjectCollection.

        This also tests that a DIASource that can't be associated within
        tolerance is appended to the DIAObjectCollection as a new
        DIAObject.
        """

        assoc_task = AssociationTask()
        # Create a set of DIAObjects that contain only one DIASource
        n_objects = 5
        dia_objects = create_test_points_pandas(
            point_locs_deg=[[0.04 * obj_idx, 0.04 * obj_idx]
                            for obj_idx in range(n_objects)],
            start_id=0,
            schema=self.dia_object_schema,
            scatter_arcsec=-1,)
        dia_objects.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaObjectId"},
                           inplace=True)

        n_sources = 5
        dia_sources = create_test_points_pandas(
            point_locs_deg=[
                [0.04 * (src_idx + 1),
                 0.04 * (src_idx + 1)]
                for src_idx in range(n_sources)],
            start_id=n_objects,
            scatter_arcsec=-1)
        dia_sources.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaSourceId"},
                           inplace=True)

        score_struct = assoc_task.score(dia_objects,
                                        dia_sources,
                                        1.0 * geom.arcseconds)
        self.assertFalse(np.isfinite(score_struct.scores[-1]))
        for src_idx in range(4):
            # Our scores should be extremely close to 0 but not exactly so due
            # to machine noise.
            self.assertAlmostEqual(score_struct.scores[src_idx], 0.0,
                                   places=16)

        # After matching each DIAObject should now contain 2 DIASources
        # except the last DIAObject in this collection which should be
        # newly created during the matching step and contain only one
        # DIASource.
        match_result = assoc_task.match(dia_objects, dia_sources, score_struct)
        updated_ids = match_result.associated_dia_object_ids
        self.assertEqual(len(updated_ids), 5)
        self.assertEqual(match_result.n_updated_dia_objects, 4)
        self.assertEqual(match_result.n_new_dia_objects, 1)
        self.assertEqual(match_result.n_unassociated_dia_objects, 1)

        # Test updating all DiaObjects
        n_objects = 4
        dia_objects = create_test_points_pandas(
            point_locs_deg=[[0.04 * obj_idx, 0.04 * obj_idx]
                            for obj_idx in range(n_objects)],
            start_id=0,
            schema=self.dia_object_schema,
            scatter_arcsec=-1,)
        dia_objects.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaObjectId"},
                           inplace=True)

        n_sources = 4
        dia_sources = create_test_points_pandas(
            point_locs_deg=[
                [0.04 * src_idx,
                 0.04 * src_idx]
                for src_idx in range(n_sources)],
            start_id=n_objects,
            scatter_arcsec=-1)

        dia_sources.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaSourceId"},
                           inplace=True)
        score_struct = assoc_task.score(dia_objects[1:],
                                        dia_sources[:-1],
                                        1.0 * geom.arcseconds)
        match_result = assoc_task.match(dia_objects, dia_sources, score_struct)
        updated_ids = match_result.associated_dia_object_ids
        self.assertEqual(len(updated_ids), 4)

    def test_remove_nan_dia_sources(self):
        """Test that check_dia_source_radec drops DIASources whose ra,
        decl, or both are NaN.
        """
        n_sources = 6
        dia_sources = create_test_points_pandas(
            point_locs_deg=[
                [0.04 * (src_idx + 1),
                 0.04 * (src_idx + 1)]
                for src_idx in range(n_sources)],
            start_id=0,
            scatter_arcsec=-1)
        dia_sources.rename(columns={"coord_ra": "ra",
                                    "coord_dec": "decl",
                                    "id": "diaSourceId"},
                           inplace=True)

        # Three of the six sources get a NaN coordinate.
        dia_sources.loc[2, "ra"] = np.nan
        dia_sources.loc[3, "decl"] = np.nan
        dia_sources.loc[4, "ra"] = np.nan
        dia_sources.loc[4, "decl"] = np.nan
        assoc_task = AssociationTask()
        out_dia_sources = assoc_task.check_dia_source_radec(dia_sources)
        self.assertEqual(len(out_dia_sources), n_sources - 3)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the standard `lsst.utils.tests.MemoryTestCase` checks for this
    module; no additional tests are defined here."""
def setup_module(module):
    """pytest module-level setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module : `module`
        The module being set up; not used, but required by the hook
        signature.
    """
    lsst.utils.tests.init()
# Allow running this file directly as a script (outside pytest).
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()