Coverage for python/lsst/ap/association/association.py: 32%
83 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 09:33 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 09:33 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""A simple implementation of source association task for ap_verify.

This module provides `AssociationConfig` and `AssociationTask`. The task
pairs newly detected DIASources with existing DIAObjects by looking up the
nearest DIAObject within a configurable maximum separation.
"""

# Names exported by ``from ... import *``.
__all__ = ["AssociationConfig", "AssociationTask"]
27import numpy as np
28import pandas as pd
29from scipy.spatial import cKDTree
31import lsst.geom as geom
32import lsst.pex.config as pexConfig
33import lsst.pipe.base as pipeBase
34from lsst.utils.timer import timeMethod
# Ask pandas to raise an error, instead of only warning, when a value is set
# through a chained-indexing expression (unsafe column/array value setting).
pd.set_option("mode.chained_assignment", "raise")
class AssociationConfig(pexConfig.Config):
    """Configuration options for `AssociationTask`."""

    # Search radius used when looking for the DIAObject nearest to each
    # DIASource; given in arcseconds.
    maxDistArcSeconds = pexConfig.Field(
        doc=("Maximum distance in arcseconds to test for a DIASource to be a "
             "match to a DIAObject."),
        dtype=float,
        default=1.0,
    )
class AssociationTask(pipeBase.Task):
    """Associate DIASources with existing DIAObjects.

    The DIASources detected in a visit are compared with the DIAObjects
    that were detected previously, and each DIASource that lies close
    enough to a DIAObject is tagged with that object's id. DIASources for
    which no DIAObject is found are returned separately by `run`; this
    class itself does not create DIAObjects for them.
    """

    # Default name under which the task is registered.
    _DefaultName = "association"
    # Configuration class paired with this task.
    ConfigClass = AssociationConfig
64 @timeMethod
65 def run(self,
66 diaSources,
67 diaObjects):
68 """Associate the new DiaSources with existing DiaObjects.
70 Parameters
71 ----------
72 diaSources : `pandas.DataFrame`
73 New DIASources to be associated with existing DIAObjects.
74 diaObjects : `pandas.DataFrame`
75 Existing diaObjects from the Apdb.
77 Returns
78 -------
79 result : `lsst.pipe.base.Struct`
80 Results struct with components.
82 - ``matchedDiaSources`` : DiaSources that were matched. Matched
83 Sources have their diaObjectId updated and set to the id of the
84 diaObject they were matched to. (`pandas.DataFrame`)
85 - ``unAssocDiaSources`` : DiaSources that were not matched.
86 Unassociated sources have their diaObject set to 0 as they
87 were not associated with any existing DiaObjects.
88 (`pandas.DataFrame`)
89 - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
90 matched to new DiaSources. (`int`)
91 - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
92 not matched a new DiaSource. (`int`)
93 """
94 diaSources = self.check_dia_source_radec(diaSources)
96 if len(diaObjects) == 0:
97 return pipeBase.Struct(
98 matchedDiaSources=pd.DataFrame(columns=diaSources.columns),
99 unAssocDiaSources=diaSources,
100 nUpdatedDiaObjects=0,
101 nUnassociatedDiaObjects=0)
103 matchResult = self.associate_sources(diaObjects, diaSources)
105 mask = matchResult.diaSources["diaObjectId"] != 0
107 return pipeBase.Struct(
108 matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True),
109 unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True),
110 nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects,
111 nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects)
113 def check_dia_source_radec(self, dia_sources):
114 """Check that all DiaSources have non-NaN values for RA/DEC.
116 If one or more DiaSources are found to have NaN values, throw a
117 warning to the log with the ids of the offending sources. Drop them
118 from the table.
120 Parameters
121 ----------
122 dia_sources : `pandas.DataFrame`
123 Input DiaSources to check for NaN values.
125 Returns
126 -------
127 trimmed_sources : `pandas.DataFrame`
128 DataFrame of DiaSources trimmed of all entries with NaN values for
129 RA/DEC.
130 """
131 nan_mask = (dia_sources.loc[:, "ra"].isnull()
132 | dia_sources.loc[:, "dec"].isnull())
133 if np.any(nan_mask):
134 nan_idxs = np.argwhere(nan_mask.to_numpy()).flatten()
135 for nan_idx in nan_idxs:
136 self.log.warning(
137 "DiaSource %i has NaN value for RA/DEC, "
138 "dropping from association." %
139 dia_sources.loc[nan_idx, "diaSourceId"])
140 dia_sources = dia_sources[~nan_mask]
141 return dia_sources
143 @timeMethod
144 def associate_sources(self, dia_objects, dia_sources):
145 """Associate the input DIASources with the catalog of DIAObjects.
147 DiaObject DataFrame must be indexed on ``diaObjectId``.
149 Parameters
150 ----------
151 dia_objects : `pandas.DataFrame`
152 Catalog of DIAObjects to attempt to associate the input
153 DIASources into.
154 dia_sources : `pandas.DataFrame`
155 DIASources to associate into the DIAObjectCollection.
157 Returns
158 -------
159 result : `lsst.pipe.base.Struct`
160 Results struct with components.
162 - ``diaSources`` : Full set of diaSources both matched and not.
163 (`pandas.DataFrame`)
164 - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
165 associated. (`int`)
166 - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
167 not matched a new DiaSource. (`int`)
168 """
169 scores = self.score(
170 dia_objects, dia_sources,
171 self.config.maxDistArcSeconds * geom.arcseconds)
172 match_result = self.match(dia_objects, dia_sources, scores)
174 return match_result
176 @timeMethod
177 def score(self, dia_objects, dia_sources, max_dist):
178 """Compute a quality score for each dia_source/dia_object pair
179 between this catalog of DIAObjects and the input DIASource catalog.
181 ``max_dist`` sets maximum separation in arcseconds to consider a
182 dia_source a possible match to a dia_object. If the pair is
183 beyond this distance no score is computed.
185 Parameters
186 ----------
187 dia_objects : `pandas.DataFrame`
188 A contiguous catalog of DIAObjects to score against dia_sources.
189 dia_sources : `pandas.DataFrame`
190 A contiguous catalog of dia_sources to "score" based on distance
191 and (in the future) other metrics.
192 max_dist : `lsst.geom.Angle`
193 Maximum allowed distance to compute a score for a given DIAObject
194 DIASource pair.
196 Returns
197 -------
198 result : `lsst.pipe.base.Struct`
199 Results struct with components:
201 - ``scores``: array of floats of match quality updated DIAObjects
202 (array-like of `float`).
203 - ``obj_idxs``: indexes of the matched DIAObjects in the catalog.
204 (array-like of `int`)
205 - ``obj_ids``: array of floats of match quality updated DIAObjects
206 (array-like of `int`).
208 Default values for these arrays are
209 INF, -1, and -1 respectively for unassociated sources.
210 """
211 scores = np.full(len(dia_sources), np.inf, dtype=np.float64)
212 obj_idxs = np.full(len(dia_sources), -1, dtype=np.int64)
213 obj_ids = np.full(len(dia_sources), 0, dtype=np.int64)
215 if len(dia_objects) == 0:
216 return pipeBase.Struct(
217 scores=scores,
218 obj_idxs=obj_idxs,
219 obj_ids=obj_ids)
221 spatial_tree = self._make_spatial_tree(dia_objects)
223 max_dist_rad = max_dist.asRadians()
225 vectors = self._radec_to_xyz(dia_sources)
227 scores, obj_idxs = spatial_tree.query(
228 vectors,
229 distance_upper_bound=max_dist_rad)
230 matched_src_idxs = np.argwhere(np.isfinite(scores))
231 obj_ids[matched_src_idxs] = dia_objects.index.to_numpy()[
232 obj_idxs[matched_src_idxs]]
234 return pipeBase.Struct(
235 scores=scores,
236 obj_idxs=obj_idxs,
237 obj_ids=obj_ids)
239 def _make_spatial_tree(self, dia_objects):
240 """Create a searchable kd-tree the input dia_object positions.
242 Parameters
243 ----------
244 dia_objects : `pandas.DataFrame`
245 A catalog of DIAObjects to create the tree from.
247 Returns
248 -------
249 kd_tree : `scipy.spatical.cKDTree`
250 Searchable kd-tree created from the positions of the DIAObjects.
251 """
252 vectors = self._radec_to_xyz(dia_objects)
253 return cKDTree(vectors)
255 def _radec_to_xyz(self, catalog):
256 """Convert input ra/dec coordinates to spherical unit-vectors.
258 Parameters
259 ----------
260 catalog : `pandas.DataFrame`
261 Catalog to produce spherical unit-vector from.
263 Returns
264 -------
265 vectors : `numpy.ndarray`, (N, 3)
266 Output unit-vectors
267 """
268 ras = np.radians(catalog["ra"])
269 decs = np.radians(catalog["dec"])
270 vectors = np.empty((len(ras), 3))
272 sin_dec = np.sin(np.pi / 2 - decs)
273 vectors[:, 0] = sin_dec * np.cos(ras)
274 vectors[:, 1] = sin_dec * np.sin(ras)
275 vectors[:, 2] = np.cos(np.pi / 2 - decs)
277 return vectors
279 @timeMethod
280 def match(self, dia_objects, dia_sources, score_struct):
281 """Match DIAsources to DiaObjects given a score.
283 Parameters
284 ----------
285 dia_objects : `pandas.DataFrame`
286 A SourceCatalog of DIAObjects to associate to DIASources.
287 dia_sources : `pandas.DataFrame`
288 A contiguous catalog of dia_sources for which the set of scores
289 has been computed on with DIAObjectCollection.score.
290 score_struct : `lsst.pipe.base.Struct`
291 Results struct with components:
293 - ``"scores"``: array of floats of match quality
294 updated DIAObjects (array-like of `float`).
295 - ``"obj_ids"``: array of floats of match quality
296 updated DIAObjects (array-like of `int`).
297 - ``"obj_idxs"``: indexes of the matched DIAObjects in the catalog.
298 (array-like of `int`)
300 Default values for these arrays are
301 INF, -1 and -1 respectively for unassociated sources.
303 Returns
304 -------
305 result : `lsst.pipe.base.Struct`
306 Results struct with components.
308 - ``"diaSources"`` : Full set of diaSources both matched and not.
309 (`pandas.DataFrame`)
310 - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
311 associated. (`int`)
312 - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
313 not matched a new DiaSource. (`int`)
314 """
315 n_previous_dia_objects = len(dia_objects)
316 used_dia_object = np.zeros(n_previous_dia_objects, dtype=bool)
317 used_dia_source = np.zeros(len(dia_sources), dtype=bool)
318 associated_dia_object_ids = np.zeros(len(dia_sources),
319 dtype=np.uint64)
320 n_updated_dia_objects = 0
322 # We sort from best match to worst to effectively perform a
323 # "handshake" match where both the DIASources and DIAObjects agree
324 # their the best match. By sorting this way, scores with NaN (those
325 # sources that have no match and will create new DIAObjects) will be
326 # placed at the end of the array.
327 score_args = score_struct.scores.argsort(axis=None)
328 for score_idx in score_args:
329 if not np.isfinite(score_struct.scores[score_idx]):
330 # Thanks to the sorting the rest of the sources will be
331 # NaN for their score. We therefore exit the loop to append
332 # sources to a existing DIAObject, leaving these for
333 # the loop creating new objects.
334 break
335 dia_obj_idx = score_struct.obj_idxs[score_idx]
336 if used_dia_object[dia_obj_idx]:
337 continue
338 used_dia_object[dia_obj_idx] = True
339 used_dia_source[score_idx] = True
340 obj_id = score_struct.obj_ids[score_idx]
341 associated_dia_object_ids[score_idx] = obj_id
342 dia_sources.loc[score_idx, "diaObjectId"] = obj_id
343 n_updated_dia_objects += 1
345 return pipeBase.Struct(
346 diaSources=dia_sources,
347 nUpdatedDiaObjects=n_updated_dia_objects,
348 nUnassociatedDiaObjects=(n_previous_dia_objects
349 - n_updated_dia_objects))