Coverage for tests/test_association_task.py : 10%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import pandas as pd
24import unittest
26from lsst.afw.cameraGeom.testUtils import DetectorWrapper
27import lsst.afw.geom as afwGeom
28import lsst.afw.image as afwImage
29import lsst.afw.table as afwTable
30import lsst.daf.base as dafBase
31import lsst.geom as geom
32import lsst.sphgeom as sphgeom
33import lsst.utils.tests
35from lsst.ap.association import \
36 AssociationTask, \
37 make_dia_source_schema, \
38 make_dia_object_schema
def create_test_points(point_locs_deg,
                       wcs=None,
                       start_id=0,
                       schema=None,
                       scatter_arcsec=1.0,
                       indexer_ids=None,
                       associated_ids=None):
    """Create dummy DIASources or DIAObjects for use in our tests.

    Parameters
    ----------
    point_locs_deg : array-like (N, 2) of `float`
        Positions of the test points to create in RA, DEC.
    wcs : `lsst.afw.geom.SkyWcs`, optional
        Wcs to convert RA/Dec to x/y if provided. When `None` the ``x`` and
        ``y`` fields are left untouched.
    start_id : `int`
        Unique id of the first object to create. The remaining sources are
        incremented by one from the first id.
    schema : `lsst.afw.table.Schema`, optional
        Schema of the objects to create. Defaults to the DIASource schema.
    scatter_arcsec : `float`
        Maximum random offset, in arcseconds, added to the position of each
        point. A value less than or equal to zero disables the scatter so the
        points land exactly on ``point_locs_deg``.
    indexer_ids : `list` of `int`, optional
        Id numbers of pixelization indexer to store. Must be the same length
        as the first dimension of point_locs_deg.
    associated_ids : `list` of `int`, optional
        Id numbers of associated DIAObjects to store. Must be the same length
        as the first dimension of point_locs_deg.

    Returns
    -------
    test_points : `lsst.afw.table.SourceCatalog`
        Catalog of points to test.
    """
    if schema is None:
        schema = make_dia_source_schema()
    sources = afwTable.SourceCatalog(schema)

    for src_idx, (ra, dec) in enumerate(point_locs_deg):
        src = sources.addNew()
        src['id'] = src_idx + start_id
        coord = geom.SpherePoint(ra, dec, geom.degrees)
        if scatter_arcsec > 0.0:
            # Random bearing in [0, 360) degrees and random distance in
            # [0, scatter_arcsec) arcseconds. The generator is not seeded
            # here, so the offsets differ from run to run.
            coord = coord.offset(
                np.random.rand() * 360 * geom.degrees,
                np.random.rand() * scatter_arcsec * geom.arcseconds)
        if indexer_ids is not None:
            src['pixelId'] = indexer_ids[src_idx]
        if associated_ids is not None:
            src['diaObjectId'] = associated_ids[src_idx]
        src.setCoord(coord)

        if wcs is not None:
            # Project the (possibly scattered) sky position onto the image.
            xyCentroid = wcs.skyToPixel(coord)
            src.set("x", xyCentroid.getX())
            src.set("y", xyCentroid.getY())

    return sources
def create_test_points_pandas(point_locs_deg,
                              wcs=None,
                              start_id=0,
                              schema=None,
                              scatter_arcsec=1.0,
                              indexer_ids=None,
                              associated_ids=None):
    """Create dummy DIASources or DIAObjects as a pandas DataFrame.

    The points are first built as an afw catalog and then converted with
    ``asAstropy().to_pandas()``; column names and coordinate values are
    exactly what that conversion produces (no renaming or unit conversion is
    applied here).

    Parameters
    ----------
    point_locs_deg : array-like (N, 2) of `float`
        Positions of the test points to create in RA, DEC.
    wcs : `lsst.afw.geom.SkyWcs`, optional
        Wcs to convert RA/Dec to x/y if provided. When `None` the ``x`` and
        ``y`` fields are left untouched.
    start_id : `int`
        Unique id of the first object to create. The remaining sources are
        incremented by one from the first id.
    schema : `lsst.afw.table.Schema`, optional
        Schema of the objects to create. Defaults to the DIASource schema.
    scatter_arcsec : `float`
        Maximum random offset, in arcseconds, added to the position of each
        point. A value less than or equal to zero disables the scatter so the
        points land exactly on ``point_locs_deg``.
    indexer_ids : `list` of `int`, optional
        Id numbers of pixelization indexer to store. Must be the same length
        as the first dimension of point_locs_deg.
    associated_ids : `list` of `int`, optional
        Id numbers of associated DIAObjects to store. Must be the same length
        as the first dimension of point_locs_deg.

    Returns
    -------
    test_points : `pandas.DataFrame`
        Catalog of points to test.
    """
    if schema is None:
        schema = make_dia_source_schema()
    catalog = afwTable.SourceCatalog(schema)

    for src_idx, (ra, dec) in enumerate(point_locs_deg):
        src = catalog.addNew()
        src['id'] = src_idx + start_id
        coord = geom.SpherePoint(ra, dec, geom.degrees)
        if scatter_arcsec > 0.0:
            # Random bearing in [0, 360) degrees and random distance in
            # [0, scatter_arcsec) arcseconds. The generator is not seeded
            # here, so the offsets differ from run to run.
            coord = coord.offset(
                np.random.rand() * 360 * geom.degrees,
                np.random.rand() * scatter_arcsec * geom.arcseconds)
        if indexer_ids is not None:
            src['pixelId'] = indexer_ids[src_idx]
        if associated_ids is not None:
            src['diaObjectId'] = associated_ids[src_idx]
        src.setCoord(coord)

        if wcs is not None:
            # Project the (possibly scattered) sky position onto the image.
            xyCentroid = wcs.skyToPixel(coord)
            src.set("x", xyCentroid.getX())
            src.set("y", xyCentroid.getY())

    return catalog.asAstropy().to_pandas()
class TestAssociationTask(unittest.TestCase):
    """Tests for `lsst.ap.association.AssociationTask`.

    The fixtures are built entirely in memory: a WCS and exposure derived
    from CFHT header values, plus DIAObject/DIASource tables generated by
    ``create_test_points`` and ``create_test_points_pandas``.
    """

    def setUp(self):
        """Build the WCS, exposure and HTM pixelization shared by the tests.
        """
        self.filter_names = ["u", "g", "r", "i", "z"]
        self.dia_object_schema = make_dia_object_schema()

        # metadata taken from CFHT data
        # v695856-e0/v695856-e0-c000-a00.sci_img.fits
        self.metadata = dafBase.PropertySet()

        self.metadata.set("SIMPLE", "T")
        self.metadata.set("BITPIX", -32)
        self.metadata.set("NAXIS", 2)
        self.metadata.set("NAXIS1", 1024)
        self.metadata.set("NAXIS2", 1153)
        self.metadata.set("RADECSYS", 'FK5')
        self.metadata.set("EQUINOX", 2000.)

        self.metadata.setDouble("CRVAL1", 215.604025685476)
        self.metadata.setDouble("CRVAL2", 53.1595451514076)
        self.metadata.setDouble("CRPIX1", 1109.99981456774)
        self.metadata.setDouble("CRPIX2", 560.018167811613)
        self.metadata.set("CTYPE1", 'RA---SIN')
        self.metadata.set("CTYPE2", 'DEC--SIN')

        self.metadata.setDouble("CD1_1", 5.10808596133527E-05)
        self.metadata.setDouble("CD1_2", 1.85579539217196E-07)
        self.metadata.setDouble("CD2_2", -5.10281493481982E-05)
        self.metadata.setDouble("CD2_1", -8.27440751733828E-07)

        self.wcs = afwGeom.makeSkyWcs(self.metadata)
        self.exposure = afwImage.makeExposure(
            afwImage.makeMaskedImageFromArrays(np.ones((1024, 1153))),
            self.wcs)
        detector = DetectorWrapper(id=23, bbox=self.exposure.getBBox()).detector
        visit = afwImage.VisitInfo(
            exposureId=1234,
            exposureTime=200.,
            date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
                                  dafBase.DateTime.Timescale.TAI))
        self.exposure.setDetector(detector)
        self.exposure.getInfo().setVisitInfo(visit)
        self.exposure.setFilterLabel(afwImage.FilterLabel(band='g'))
        self.flux0 = 10000
        self.flux0_err = 100
        self.exposure.setPhotoCalib(
            afwImage.PhotoCalib(self.flux0, self.flux0_err))

        bbox = geom.Box2D(self.exposure.getBBox())
        wcs = self.exposure.getWcs()

        # Pixelization used to fill the DIAObject ``pixelId`` column, and the
        # index ranges enveloping the exposure footprint on the sky.
        self.pixelator = sphgeom.HtmPixelization(20)
        region = sphgeom.ConvexPolygon([wcs.pixelToSky(pp).getVector()
                                        for pp in bbox.getCorners()])

        indices = self.pixelator.envelope(region, 64)
        self.index_ranges = indices.ranges()

    def tearDown(self):
        """Release the metadata, WCS and exposure created in ``setUp``.
        """
        del self.metadata
        del self.wcs
        del self.exposure

    def test_run(self):
        """Test the run method when DIAObjects and DIASources already exist.
        """
        dia_objects = self._run_association_and_retrieve_objects(True)
        not_updated_idx = 0
        updated_idx_start = 1
        new_idx_start = 5
        total_expected_dia_objects = 10
        self.assertEqual(len(dia_objects), total_expected_dia_objects)

        # Test to make sure the number of DIAObjects have been properly
        # associated.
        for obj_idx, df_idx in enumerate(dia_objects.index):
            if df_idx == not_updated_idx:
                # Test the DIAObject we expect to not be associated with any
                # new DIASources.
                self.assertEqual(df_idx, obj_idx)
            elif updated_idx_start <= df_idx < new_idx_start:
                # Test that associating to the existing DIAObjects went
                # as planned and test that the IDs of the newly associated
                # DIASources is correct.
                self.assertEqual(df_idx, obj_idx)
            else:
                # Rows 5..9 are expected to carry ids 14..18: the ids of the
                # five input DIASources (ids start at 10) that have no
                # pre-existing DIAObject at their position.
                self.assertEqual(df_idx, obj_idx + 4 + 5)

    def test_run_no_existing_objects(self):
        """Test the run method with no pre-existing DIAObjects.
        """
        dia_objects = self._run_association_and_retrieve_objects(False)
        total_expected_dia_objects = 9
        self.assertEqual(len(dia_objects),
                         total_expected_dia_objects)
        # Every input DIASource (ids start at 10) is expected to produce a
        # DIAObject indexed by the same id.
        for obj_idx, df_idx in enumerate(dia_objects.index):
            self.assertEqual(df_idx, obj_idx + 10)

    def test_run_dup_diaObjects(self):
        """Test that duplicate objects being run through association throw the
        correct error.
        """
        with self.assertRaises(RuntimeError):
            self._run_association_and_retrieve_objects(create_objects=True,
                                                       dupDiaObjects=True)

    def _diagonal_sky_positions(self):
        """Return RA/Dec for ten points along the pixel diagonal.

        Returns
        -------
        positions : `list` of [`float`, `float`]
            ``[ra, dec]`` in degrees of the pixels ``(idx, idx)`` for ``idx``
            in ``np.linspace(1, 1000, 10)``, computed with ``self.wcs``.
        """
        positions = []
        for idx in np.linspace(1, 1000, 10):
            sky = self.wcs.pixelToSky(idx, idx)
            positions.append([sky.getRa().asDegrees(),
                              sky.getDec().asDegrees()])
        return positions

    @staticmethod
    def _rename_columns(data_frame, id_column, rename_parent=False):
        """Rename afw catalog columns, in place, to the names the tests use.

        Parameters
        ----------
        data_frame : `pandas.DataFrame`
            Table produced from an afw catalog; modified in place.
        id_column : `str`
            New name for the ``id`` column.
        rename_parent : `bool`
            If True also rename ``parent`` to ``parentDiaSourceId``.
        """
        columns = {"coord_ra": "ra",
                   "coord_dec": "decl",
                   "id": id_column}
        if rename_parent:
            columns["parent"] = "parentDiaSourceId"
        data_frame.rename(columns=columns, inplace=True)

    @classmethod
    def _catalog_to_data_frame(cls, catalog, id_column, rename_parent=False):
        """Convert an afw catalog to a DataFrame with coordinates in degrees.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Catalog to convert.
        id_column : `str`
            New name for the ``id`` column.
        rename_parent : `bool`
            If True also rename ``parent`` to ``parentDiaSourceId``.

        Returns
        -------
        data_frame : `pandas.DataFrame`
            Converted table with ``ra``/``decl`` passed through
            ``np.degrees``.
        """
        data_frame = catalog.asAstropy().to_pandas()
        cls._rename_columns(data_frame, id_column, rename_parent)
        data_frame["ra"] = np.degrees(data_frame["ra"])
        data_frame["decl"] = np.degrees(data_frame["decl"])
        return data_frame

    def _run_association_and_retrieve_objects(self,
                                              create_objects=False,
                                              dupDiaObjects=False):
        """Convenience method for testing the Association run method.

        Parameters
        ----------
        create_objects : `bool`
            Boolean specifying if seed DIAObjects and DIASources should be
            passed to the task in addition to the new DIASources.
        dupDiaObjects : `bool`
            Add duplicate diaObjects into processing to force an error. Must
            be used with ``create_objects`` equal to True.

        Returns
        -------
        dia_objects : `pandas.DataFrame`
            The ``diaObjects`` attribute of the result of
            ``AssociationTask.run``.
        """
        if create_objects:
            diaObjects, diaSourceHistory = \
                self._create_dia_objects_and_sources()
        else:
            diaObjects = pd.DataFrame(columns=["diaObjectId"])
            diaSourceHistory = pd.DataFrame(columns=["diaObjectId",
                                                     "filterName",
                                                     "diaSourceId"])
        diaObjects.set_index("diaObjectId",
                             inplace=True,
                             drop=False)
        diaSourceHistory.set_index(["diaObjectId",
                                    "filterName",
                                    "diaSourceId"],
                                   inplace=True,
                                   drop=False)

        # Nine new DIASources: the diagonal positions minus the first one.
        source_centers = self._diagonal_sky_positions()[1:]
        dia_sources = create_test_points(
            point_locs_deg=source_centers,
            start_id=10,
            scatter_arcsec=-1)

        # These values are the same for every source, so look them up once.
        visit_info = self.exposure.getInfo().getVisitInfo()
        band = self.exposure.getFilterLabel().bandLabel
        ccd_visit_id = visit_info.getExposureId()
        mid_point_tai = visit_info.getDate().get(system=dafBase.DateTime.MJD)
        for dia_source in dia_sources:
            self._set_source_values(
                dia_source=dia_source,
                flux=10000,
                fluxErr=100,
                filterName=band,
                ccdVisitId=ccd_visit_id,
                midPointTai=mid_point_tai)

        assoc_task = AssociationTask()

        diaSources = self._catalog_to_data_frame(dia_sources,
                                                 "diaSourceId",
                                                 rename_parent=True)

        if dupDiaObjects:
            # ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat``
            # builds the same table with the first and last rows repeated.
            diaObjects = pd.concat([diaObjects, diaObjects.iloc[[0, -1]]],
                                   ignore_index=True)
            diaObjects.set_index("diaObjectId",
                                 inplace=True,
                                 drop=False)

        results = assoc_task.run(diaSources,
                                 diaObjects,
                                 diaSourceHistory)
        return results.diaObjects

    def _set_source_values(self, dia_source, flux, fluxErr, filterName,
                           ccdVisitId, midPointTai):
        """Set fluxes and visit info for DiaSources.

        The ``psFlux``, ``apFlux`` and ``totFlux`` fields all receive the same
        value, ``flux / self.flux0``, and their error fields the same
        propagated error. ``x`` and ``y`` are reset to zero.

        Parameters
        ----------
        dia_source : `lsst.afw.table.SourceRecord`
            SourceRecord object to edit.
        flux : `double`
            Flux of DiaSource
        fluxErr : `double`
            Flux error of DiaSource
        filterName : `string`
            Name of filter for flux.
        ccdVisitId : `int`
            Integer id of this ccd/visit.
        midPointTai : `double`
            Time of observation
        """
        scaled_flux = flux / self.flux0
        # Combine the error on the flux with the error on the zero point.
        scaled_flux_err = np.sqrt(
            (fluxErr / self.flux0) ** 2
            + (flux * self.flux0_err / self.flux0 ** 2) ** 2)

        dia_source['ccdVisitId'] = ccdVisitId
        dia_source["midPointTai"] = midPointTai
        for flux_name in ("psFlux", "apFlux", "totFlux"):
            dia_source[flux_name] = scaled_flux
            dia_source[flux_name + "Err"] = scaled_flux_err
        dia_source["filterName"] = filterName
        dia_source["x"] = 0.
        dia_source["y"] = 0.

    def _create_dia_objects_and_sources(self):
        """Create seed DIAObjects and the DIASources associated with them.

        Five DIAObjects (ids 0-4) are created, each with two DIASources at
        the same position: one "observed" in g and one in r.

        Returns
        -------
        dia_objects : `pandas.DataFrame`
            The seed DIAObjects with ``ra``/``decl`` in degrees.
        dia_sources : `pandas.DataFrame`
            The ten seed DIASources with ``ra``/``decl`` in degrees.
        """
        # Create DIAObjects and give them fluxes.
        n_objects = 5
        object_centers = np.array(self._diagonal_sky_positions())
        dia_objects = create_test_points(
            point_locs_deg=object_centers[:n_objects],
            start_id=0,
            schema=self.dia_object_schema,
            scatter_arcsec=-1,)
        # Set the DIAObject fluxes and number of associated sources.
        for dia_object in dia_objects:
            dia_object["nDiaSources"] = 2
            # The pixel index depends only on position, not on the filter.
            sphPoint = geom.SpherePoint(dia_object.getCoord())
            dia_object["pixelId"] = self.pixelator.index(sphPoint.getVector())
            for filter_name in self.filter_names:
                dia_object['%sPSFluxMean' % filter_name] = 1
                dia_object['%sPSFluxMeanErr' % filter_name] = 1
                dia_object['%sPSFluxSigma' % filter_name] = 1
                dia_object['%sPSFluxNdata' % filter_name] = 1
        dia_objects = self._catalog_to_data_frame(dia_objects, "diaObjectId")

        dateTime = dafBase.DateTime("2014-05-13T16:00:00.000000000",
                                    dafBase.DateTime.Timescale.TAI)
        mid_point_tai = dateTime.get(system=dafBase.DateTime.MJD)

        # Create DIASources and update their ccdVisitId and fluxes. The first
        # five are the g-band visit, the last five the r-band visit.
        dia_sources = create_test_points(
            point_locs_deg=np.concatenate(
                [object_centers[:n_objects], object_centers[:n_objects]]),
            start_id=0,
            scatter_arcsec=-1,
            associated_ids=[0, 1, 2, 3, 4,
                            0, 1, 2, 3, 4])
        for src_idx, dia_source in enumerate(dia_sources):
            if src_idx < n_objects:
                filter_name, ccd_visit_id = 'g', 1232
            else:
                filter_name, ccd_visit_id = 'r', 1233
            self._set_source_values(
                dia_source=dia_source,
                flux=10000,
                fluxErr=100,
                filterName=filter_name,
                ccdVisitId=ccd_visit_id,
                midPointTai=mid_point_tai)
        dia_sources = self._catalog_to_data_frame(dia_sources,
                                                  "diaSourceId",
                                                  rename_parent=True)
        return dia_objects, dia_sources

    def test_associate_sources(self):
        """Test the performance of the associate_sources method in
        AssociationTask.
        """
        n_objects = 5
        dia_objects = create_test_points_pandas(
            point_locs_deg=[[0.04 * obj_idx, 0.04 * obj_idx]
                            for obj_idx in range(n_objects)],
            start_id=0,
            schema=self.dia_object_schema,
            scatter_arcsec=-1,)
        self._rename_columns(dia_objects, "diaObjectId")

        n_sources = 5
        dia_sources = create_test_points_pandas(
            point_locs_deg=[
                [0.04 * (src_idx + 1),
                 0.04 * (src_idx + 1)]
                for src_idx in range(n_sources)],
            start_id=n_objects,
            scatter_arcsec=0.1)
        self._rename_columns(dia_sources, "diaSourceId")

        assoc_task = AssociationTask()
        assoc_result = assoc_task.associate_sources(
            dia_objects, dia_sources)

        # The first four sources sit on objects 1-4; the last one has no
        # counterpart and is expected to be reported under its own id (9).
        for test_obj_id, expected_obj_id in zip(
                assoc_result.associated_dia_object_ids,
                [1, 2, 3, 4, 9]):
            self.assertEqual(test_obj_id, expected_obj_id)

    def test_score_and_match(self):
        """Test association between a set of sources and an existing
        DIAObjectCollection.

        This also tests that a DIASource that can't be associated within
        tolerance is appended to the DIAObjectCollection as a new
        DIAObject.
        """
        assoc_task = AssociationTask()
        # Create a set of DIAObjects that contain only one DIASource
        n_objects = 5
        dia_objects = create_test_points_pandas(
            point_locs_deg=[[0.04 * obj_idx, 0.04 * obj_idx]
                            for obj_idx in range(n_objects)],
            start_id=0,
            schema=self.dia_object_schema,
            scatter_arcsec=-1,)
        self._rename_columns(dia_objects, "diaObjectId")

        n_sources = 5
        dia_sources = create_test_points_pandas(
            point_locs_deg=[
                [0.04 * (src_idx + 1),
                 0.04 * (src_idx + 1)]
                for src_idx in range(n_sources)],
            start_id=n_objects,
            scatter_arcsec=-1)
        self._rename_columns(dia_sources, "diaSourceId")

        score_struct = assoc_task.score(dia_objects,
                                        dia_sources,
                                        1.0 * geom.arcseconds)
        self.assertFalse(np.isfinite(score_struct.scores[-1]))
        for src_idx in range(4):
            # Our scores should be extremely close to 0 but not exactly so due
            # to machine noise.
            self.assertAlmostEqual(score_struct.scores[src_idx], 0.0,
                                   places=16)

        # After matching each DIAObject should now contain 2 DIASources
        # except the last DIAObject in this collection which should be
        # newly created during the matching step and contain only one
        # DIASource.
        match_result = assoc_task.match(dia_objects, dia_sources, score_struct)
        updated_ids = match_result.associated_dia_object_ids
        self.assertEqual(len(updated_ids), 5)
        self.assertEqual(match_result.n_updated_dia_objects, 4)
        self.assertEqual(match_result.n_new_dia_objects, 1)
        self.assertEqual(match_result.n_unassociated_dia_objects, 1)

        # Test updating all DiaObjects
        n_objects = 4
        dia_objects = create_test_points_pandas(
            point_locs_deg=[[0.04 * obj_idx, 0.04 * obj_idx]
                            for obj_idx in range(n_objects)],
            start_id=0,
            schema=self.dia_object_schema,
            scatter_arcsec=-1,)
        self._rename_columns(dia_objects, "diaObjectId")

        n_sources = 4
        dia_sources = create_test_points_pandas(
            point_locs_deg=[
                [0.04 * src_idx,
                 0.04 * src_idx]
                for src_idx in range(n_sources)],
            start_id=n_objects,
            scatter_arcsec=-1)
        self._rename_columns(dia_sources, "diaSourceId")

        score_struct = assoc_task.score(dia_objects[1:],
                                        dia_sources[:-1],
                                        1.0 * geom.arcseconds)
        match_result = assoc_task.match(dia_objects, dia_sources, score_struct)
        updated_ids = match_result.associated_dia_object_ids
        self.assertEqual(len(updated_ids), 4)

    def test_remove_nan_dia_sources(self):
        """Test that DIASources with a NaN ra and/or decl are dropped by
        ``check_dia_source_radec``.
        """
        n_sources = 6
        dia_sources = create_test_points_pandas(
            point_locs_deg=[
                [0.04 * (src_idx + 1),
                 0.04 * (src_idx + 1)]
                for src_idx in range(n_sources)],
            start_id=0,
            scatter_arcsec=-1)
        self._rename_columns(dia_sources, "diaSourceId")

        # Three bad rows: NaN ra only, NaN decl only, and both NaN.
        dia_sources.loc[2, "ra"] = np.nan
        dia_sources.loc[3, "decl"] = np.nan
        dia_sources.loc[4, "ra"] = np.nan
        dia_sources.loc[4, "decl"] = np.nan
        assoc_task = AssociationTask()
        out_dia_sources = assoc_task.check_dia_source_radec(dia_sources)
        self.assertEqual(len(out_dia_sources), n_sources - 3)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    Nothing is added or overridden; the subclass exists so the inherited
    test case is collected as part of this module.
    """
def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Parameters
    ----------
    module : module
        The test module being set up; not used.
    """
    lsst.utils.tests.init()
# Allow the tests to be run directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()