Coverage for python/lsst/ap/association/association.py: 27%
83 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-18 03:13 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-18 03:13 -0800
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""A simple implementation of source association task for ap_verify.

Provides `AssociationTask`, which matches new DIASources to existing
DIAObjects by nearest-neighbour search within a configurable radius, and its
configuration class `AssociationConfig`.
"""

__all__ = ["AssociationConfig", "AssociationTask"]
27import numpy as np
28import pandas as pd
29from scipy.spatial import cKDTree
31import lsst.geom as geom
32import lsst.pex.config as pexConfig
33import lsst.pipe.base as pipeBase
34from lsst.utils.timer import timeMethod
# Enforce an error for unsafe column/array value setting in pandas.
# NOTE(review): ``pd.options`` is global pandas state, so importing this
# module changes chained-assignment behaviour for every pandas user in the
# process, not only for the code in this file.
pd.options.mode.chained_assignment = 'raise'
class AssociationConfig(pexConfig.Config):
    """Configuration for `AssociationTask`.

    Holds the matching radius used when pairing new DIASources with
    existing DIAObjects.
    """
    maxDistArcSeconds = pexConfig.Field(
        default=1.0,
        dtype=float,
        doc=("Maximum distance in arcseconds to test for a DIASource "
             "to be a match to a DIAObject."),
    )
class AssociationTask(pipeBase.Task):
    """Associate DIASources into existing DIAObjects.

    This task performs the association of detected DIASources in a visit
    with the previous DIAObjects detected over time. DIASources that cannot
    be associated with an existing DIAObject are returned separately
    (``unAssocDiaSources``); this task does not itself create new
    DIAObjects from them.
    """

    ConfigClass = AssociationConfig
    _DefaultName = "association"

    @timeMethod
    def run(self,
            diaSources,
            diaObjects):
        """Associate the new DiaSources with existing DiaObjects.

        DiaSources with a NaN ``ra`` or ``decl`` are dropped (with a logged
        warning) and appear in neither output catalog.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            New DIASources to be associated with existing DIAObjects. Must
            have ``ra``, ``decl``, ``diaSourceId`` and ``diaObjectId``
            columns.
        diaObjects : `pandas.DataFrame`
            Existing diaObjects from the Apdb, indexed on ``diaObjectId``.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``"matchedDiaSources"`` : DiaSources that were matched. Matched
              Sources have their diaObjectId updated and set to the id of the
              diaObject they were matched to. (`pandas.DataFrame`)
            - ``"unAssocDiaSources"`` : DiaSources that were not matched,
              i.e. those left with a ``diaObjectId`` of 0 (all valid input
              DiaSources when ``diaObjects`` is empty). (`pandas.DataFrame`)
            - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
              matched to new DiaSources. (`int`)
            - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)

            Both DataFrames have a fresh 0..N-1 index.
        """
        diaSources = self.check_dia_source_radec(diaSources)
        if len(diaObjects) == 0:
            # Nothing to match against. Reset the index so that this path
            # returns the same index layout as the matching path below. Slice
            # to an empty frame, rather than building one from the column
            # names, so the column dtypes are kept.
            return pipeBase.Struct(
                matchedDiaSources=diaSources.iloc[:0].reset_index(drop=True),
                unAssocDiaSources=diaSources.reset_index(drop=True),
                nUpdatedDiaObjects=0,
                nUnassociatedDiaObjects=0)

        matchResult = self.associate_sources(diaObjects, diaSources)

        # ``match`` writes a non-zero diaObjectId only into matched sources.
        mask = matchResult.diaSources["diaObjectId"] != 0

        return pipeBase.Struct(
            matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True),
            unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True),
            nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects,
            nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects)

    def check_dia_source_radec(self, dia_sources):
        """Check that all DiaSources have non-NaN values for RA/DEC.

        If one or more DiaSources are found to have NaN values, log a
        warning with the id of each offending source and drop them from
        the table.

        Parameters
        ----------
        dia_sources : `pandas.DataFrame`
            Input DiaSources to check for NaN values.

        Returns
        -------
        trimmed_sources : `pandas.DataFrame`
            DataFrame of DiaSources trimmed of all entries with NaN values for
            RA/DEC. This is the input object itself when nothing is dropped.
            The index is not reset, so it may contain gaps.
        """
        nan_mask = dia_sources["ra"].isna() | dia_sources["decl"].isna()
        if nan_mask.any():
            # Select the offending ids with the boolean mask itself. Mixing
            # positional indices with label-based ``.loc`` gives the wrong
            # rows when the index is not a plain 0..N-1 range.
            bad_ids = dia_sources.loc[nan_mask.to_numpy(), "diaSourceId"]
            for source_id in bad_ids:
                self.log.warning(
                    "DiaSource %i has NaN value for RA/DEC, "
                    "dropping from association.",
                    source_id)
            dia_sources = dia_sources[~nan_mask]
        return dia_sources

    @timeMethod
    def associate_sources(self, dia_objects, dia_sources):
        """Associate the input DIASources with the catalog of DIAObjects.

        DiaObject DataFrame must be indexed on ``diaObjectId``.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            Catalog of DIAObjects to attempt to associate the input
            DIASources into.
        dia_sources : `pandas.DataFrame`
            DIASources to associate into the DIAObjectCollection.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``"diaSources"`` : Full set of diaSources both matched and not.
              (`pandas.DataFrame`)
            - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
              associated. (`int`)
            - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)
        """
        scores = self.score(
            dia_objects, dia_sources,
            self.config.maxDistArcSeconds * geom.arcseconds)
        return self.match(dia_objects, dia_sources, scores)

    @timeMethod
    def score(self, dia_objects, dia_sources, max_dist):
        """Compute a quality score for each dia_source/dia_object pair
        between this catalog of DIAObjects and the input DIASource catalog.

        Each dia_source is scored against its nearest dia_object. If that
        object is not closer than ``max_dist``, no score is computed.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A contiguous catalog of DIAObjects to score against dia_sources,
            indexed on ``diaObjectId``.
        dia_sources : `pandas.DataFrame`
            A contiguous catalog of dia_sources to "score" based on distance
            and (in the future) other metrics.
        max_dist : `lsst.geom.Angle`
            Maximum allowed distance to compute a score for a given DIAObject
            DIASource pair.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components, one entry per dia_source:

            - ``"scores"``: distance to the nearest DIAObject, as the chord
              length between unit vectors (array-like of `float`).
            - ``"obj_idxs"``: positional indexes of the matched DIAObjects in
              ``dia_objects`` (array-like of `int`).
            - ``"obj_ids"``: ``diaObjectId`` (index value) of the matched
              DIAObjects (array-like of `int`).

            Default values for these arrays are
            INF, -1, and 0 respectively for unassociated sources.
        """
        n_sources = len(dia_sources)
        scores = np.full(n_sources, np.inf, dtype=np.float64)
        obj_idxs = np.full(n_sources, -1, dtype=np.int64)
        obj_ids = np.full(n_sources, 0, dtype=np.int64)

        if len(dia_objects) == 0 or n_sources == 0:
            return pipeBase.Struct(
                scores=scores,
                obj_idxs=obj_idxs,
                obj_ids=obj_ids)

        spatial_tree = self._make_spatial_tree(dia_objects)

        # The tree holds 3D unit vectors, so query distances are Euclidean
        # chord lengths. A chord is slightly shorter than the corresponding
        # angle in radians, but the difference is negligible at
        # arcsecond scales.
        max_dist_rad = max_dist.asRadians()

        vectors = self._radec_to_xyz(dia_sources)

        scores, tree_idxs = spatial_tree.query(
            vectors,
            distance_upper_bound=max_dist_rad)
        matched = np.isfinite(scores)
        # cKDTree marks a missing neighbour with index len(dia_objects).
        # Copy only the real matches so unmatched entries keep the
        # documented default of -1.
        obj_idxs[matched] = tree_idxs[matched]
        obj_ids[matched] = dia_objects.index.to_numpy()[tree_idxs[matched]]

        return pipeBase.Struct(
            scores=scores,
            obj_idxs=obj_idxs,
            obj_ids=obj_ids)

    def _make_spatial_tree(self, dia_objects):
        """Create a searchable kd-tree of the input dia_object positions.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A catalog of DIAObjects to create the tree from.

        Returns
        -------
        kd_tree : `scipy.spatial.cKDTree`
            Searchable kd-tree created from the positions of the DIAObjects.
        """
        vectors = self._radec_to_xyz(dia_objects)
        return cKDTree(vectors)

    def _radec_to_xyz(self, catalog):
        """Convert input ra/dec coordinates to spherical unit-vectors.

        Parameters
        ----------
        catalog : `pandas.DataFrame`
            Catalog with ``ra`` and ``decl`` columns in degrees to produce
            spherical unit-vectors from.

        Returns
        -------
        vectors : `numpy.ndarray`, (N, 3)
            Output unit-vectors.
        """
        ras = np.radians(catalog["ra"].to_numpy(dtype=np.float64))
        decs = np.radians(catalog["decl"].to_numpy(dtype=np.float64))

        cos_dec = np.cos(decs)
        vectors = np.empty((len(ras), 3))
        vectors[:, 0] = cos_dec * np.cos(ras)
        vectors[:, 1] = cos_dec * np.sin(ras)
        vectors[:, 2] = np.sin(decs)

        return vectors

    @timeMethod
    def match(self, dia_objects, dia_sources, score_struct):
        """Match DIAsources to DiaObjects given a score.

        Each DIAObject is assigned to at most one DIASource: the one with
        the best (lowest) score among those whose nearest neighbour it is.
        A DIASource that loses this "handshake" is not re-matched to another
        object and stays unassociated.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            Catalog of DIAObjects the scores were computed against. Only its
            length is used here.
        dia_sources : `pandas.DataFrame`
            A contiguous catalog of dia_sources for which the set of scores
            has been computed. Must have a ``diaObjectId`` column. It is not
            modified; an updated copy is returned.
        score_struct : `lsst.pipe.base.Struct`
            Results struct with components:

            - ``"scores"``: match distances, infinite for unmatched sources
              (array-like of `float`).
            - ``"obj_ids"``: ``diaObjectId`` of the matched DIAObjects
              (array-like of `int`).
            - ``"obj_idxs"``: indexes of the matched DIAObjects in the catalog.
              (array-like of `int`)

            Default values for these arrays are
            INF, 0 and -1 respectively for unassociated sources.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``"diaSources"`` : Full set of diaSources both matched and not;
              matched rows have ``diaObjectId`` set. (`pandas.DataFrame`)
            - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
              associated. (`int`)
            - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)
        """
        n_previous_dia_objects = len(dia_objects)
        scores = np.asarray(score_struct.scores)
        obj_idxs = np.asarray(score_struct.obj_idxs)
        obj_ids = np.asarray(score_struct.obj_ids)

        # Work on a copy so that (a) the caller's DataFrame is not mutated
        # and (b) writing into a frame produced by boolean indexing does not
        # raise SettingWithCopyError, given the module sets
        # chained_assignment='raise'.
        dia_sources = dia_sources.copy()

        # Order source positions from best (smallest) to worst score, then
        # drop the unmatched (non-finite) ones. The stable sort makes
        # tie-breaking deterministic.
        order = np.argsort(scores, kind="stable")
        order = order[np.isfinite(scores[order])]

        # "Handshake" match: each DIAObject goes to the first, i.e.
        # best-scoring, DIASource that points at it. np.unique's
        # ``return_index`` gives the first occurrence of each object index.
        _, first_occurrence = np.unique(obj_idxs[order], return_index=True)
        winners = order[first_occurrence]

        # Assign by position: after NaN sources are dropped upstream, the
        # DataFrame index is not guaranteed to be 0..N-1.
        obj_id_col = dia_sources.columns.get_loc("diaObjectId")
        dia_sources.iloc[winners, obj_id_col] = obj_ids[winners]
        n_updated_dia_objects = len(winners)

        return pipeBase.Struct(
            diaSources=dia_sources,
            nUpdatedDiaObjects=n_updated_dia_objects,
            nUnassociatedDiaObjects=(n_previous_dia_objects
                                     - n_updated_dia_objects))