lsst.pipe.tasks  21.0.0-120-g57749b33+77c36da417
simpleAssociation.py
Go to the documentation of this file.
1 # This file is part of pipe_tasks.
2 
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (https://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 #
22 
23 """Simple association algorithm for DRP.
24 Adapted from http://github.com/LSSTDESC/dia_pipe
25 """
26 
27 import numpy as np
28 import pandas as pd
29 
30 import lsst.afw.table as afwTable
31 import lsst.geom as geom
32 import lsst.pex.config as pexConfig
33 import lsst.pipe.base as pipeBase
34 
35 from .associationUtils import query_disc, eq2xyz, toIndex
36 
37 
class SimpleAssociationConfig(pexConfig.Config):
    """Tunable parameters consumed by `SimpleAssociationTask`.

    ``tolerance`` is the matching distance in arcseconds and ``nside`` is the
    HEALPix resolution used when indexing DiaObject positions.
    """
    tolerance = pexConfig.Field(
        doc="maximum distance to match sources together in arcsec",
        dtype=float,
        default=0.5,
    )
    nside = pexConfig.Field(
        doc="Healpix nside value used for indexing",
        dtype=int,
        default=2**18,
    )
51 
52 
class SimpleAssociationTask(pipeBase.Task):
    """Construct DiaObjects from a DataFrame of DiaSources by spatially
    associating the sources.

    Represents a simple, brute force algorithm that matches DiaSources into
    DiaObjects. For each DiaSource the nearest DiaObject within the matching
    tolerance is selected. If that DiaObject has already been matched by
    another DiaSource from the same ccdVisit, or no DiaObject is close
    enough, a new DiaObject is created for the DiaSource instead.
    """
    ConfigClass = SimpleAssociationConfig
    _DefaultName = "simpleAssociation"

    def run(self, diaSources, tractPatchId, skymapBits):
        """Associate DiaSources into a collection of DiaObjects using a
        brute force matching algorithm.

        DiaSources are processed one ccdVisit at a time, iterating over the
        first level of the ``(ccdVisitId, diaSourceId)`` index. Every
        DiaSource of the first ccdVisit seeds a new DiaObject without any
        matching; DiaSources of subsequent ccdVisits are matched against the
        DiaObjects that exist at that point.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DiaSources to spatially associate into DiaObjects. Must contain
            the columns ``ccdVisitId``, ``diaSourceId``, ``ra`` and ``decl``.
            The frame is modified in place: it is temporarily re-indexed and
            its ``diaObjectId`` column is written for every row.
        tractPatchId : `int`
            Unique identifier for the tract patch.
        skymapBits : `int`
            Maximum number of bits used by the ``tractPatchId`` integer
            identifier.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with attributes:

            ``assocDiaSources``
                Table of DiaSources with updated values for the DiaObjects
                they are spatially associated to (`pandas.DataFrame`). This
                is the same object as the ``diaSources`` argument.
            ``diaObjects``
                Table of DiaObjects from matching DiaSources
                (`pandas.DataFrame`). When there are no DiaSources this is
                an empty table that still carries the DiaObject columns.
        """
        # Index by (ccdVisitId, diaSourceId) so that each ccdVisit can be
        # selected as a group and individual rows can be updated by key.
        # Note that this does not reorder the rows: within a ccdVisit the
        # sources are visited in their input order.
        diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)

        # Parallel lists, all addressed by the same DiaObject array index.
        diaObjectCat = []
        diaObjectCoords = []
        healPixIndices = []

        # Create Id factory and catalog for creating DiaObjectIds.
        idFactory = afwTable.IdFactory.makeSource(tractPatchId,
                                                  64 - skymapBits)
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema(),
                                      idFactory))

        # Loop invariants: candidates are searched for within twice the
        # tolerance (arcsec); a match is accepted below the tolerance (rad).
        searchRadius = 2*self.config.tolerance
        maxMatchDist = np.deg2rad(self.config.tolerance/3600)

        for ccdVisit in diaSources.index.levels[0]:
            ccdVisitSources = diaSources.loc[ccdVisit]
            # While no DiaObject exists yet (i.e. for the first ccdVisit),
            # every DiaSource becomes a new DiaObject without matching.
            # This is decided once per ccdVisit, not once per source.
            seedVisit = len(diaObjectCat) == 0
            # DiaObjects already matched by a DiaSource of this ccdVisit.
            usedMatchIndices = set()

            for diaSourceId, diaSrc in ccdVisitSources.iterrows():
                matchIndex = None
                if not seedVisit:
                    matchIndex = self._findNearestMatch(diaSrc,
                                                        searchRadius,
                                                        maxMatchDist,
                                                        healPixIndices,
                                                        diaObjectCat)

                if matchIndex is not None \
                        and matchIndex not in usedMatchIndices:
                    self.updateCatalogs(matchIndex,
                                        diaSrc,
                                        diaSources,
                                        ccdVisit,
                                        diaSourceId,
                                        diaObjectCat,
                                        diaObjectCoords,
                                        healPixIndices)
                    usedMatchIndices.add(matchIndex)
                else:
                    # No DiaObject within tolerance, or the nearest one has
                    # already been used for this ccdVisit.
                    self.addNewDiaObject(diaSrc,
                                         diaSources,
                                         ccdVisit,
                                         diaSourceId,
                                         diaObjectCat,
                                         idCat,
                                         diaObjectCoords,
                                         healPixIndices)

        # Drop indices before returning associated diaSource catalog.
        diaSources.reset_index(inplace=True)

        if diaObjectCat:
            diaObjects = pd.DataFrame(data=diaObjectCat)
        else:
            # Keep the schema stable for callers when nothing was created.
            diaObjects = pd.DataFrame(
                columns=["diaObjectId", "ra", "decl", "nDiaSources"])

        return pipeBase.Struct(
            assocDiaSources=diaSources,
            diaObjects=diaObjects)

    def _findNearestMatch(self,
                          diaSrc,
                          searchRadius,
                          maxMatchDist,
                          healPixIndices,
                          diaObjCat):
        """Return the index of the nearest DiaObject within the tolerance.

        Parameters
        ----------
        diaSrc : `pandas.Series`
            DiaSource to match; its ``ra`` and ``decl`` values are used.
        searchRadius : `float`
            Radius in arcseconds passed to `findMatches` to select
            candidate DiaObjects.
        maxMatchDist : `float`
            Largest accepted distance, compared directly against the
            distances returned by `findMatches`.
        healPixIndices : `list` of `int`s
            HealPix indices representing the locations of each currently
            existing DiaObject.
        diaObjCat : `list` of `dict`s
            Catalog of the currently existing DiaObjects.

        Returns
        -------
        matchIndex : `int` or `None`
            Array index of the nearest candidate DiaObject if it is closer
            than ``maxMatchDist``, otherwise `None`.
        """
        matchResult = self.findMatches(diaSrc["ra"],
                                       diaSrc["decl"],
                                       searchRadius,
                                       healPixIndices,
                                       diaObjCat)
        dists = matchResult.dists
        if dists is None:
            return None
        nearest = np.argmin(dists)
        if dists[nearest] < maxMatchDist:
            return matchResult.matches[nearest]
        return None

    def addNewDiaObject(self,
                        diaSrc,
                        diaSources,
                        ccdVisit,
                        diaSourceId,
                        diaObjCat,
                        idCat,
                        diaObjCoords,
                        healPixIndices):
        """Create a new DiaObject and append its data.

        Parameters
        ----------
        diaSrc : `pandas.Series`
            Full unassociated DiaSource to create a DiaObject from.
        diaSources : `pandas.DataFrame`
            DiaSource catalog to update information in, indexed by
            ``(ccdVisitId, diaSourceId)``. The catalog is modified in place.
        ccdVisit : `int`
            Unique identifier of the ccdVisit where ``diaSrc`` was observed.
        diaSourceId : `int`
            Unique identifier of the DiaSource.
        diaObjCat : `list` of `dict`s
            Catalog of diaObjects to append the new object to.
        idCat : `lsst.afw.table.SourceCatalog`
            Catalog with the IdFactory used to generate unique DiaObject
            identifiers.
        diaObjCoords : `list` of `list`s of `lsst.geom.SpherePoint`s
            Set of coordinates of DiaSource locations that make up the
            DiaObject average coordinate.
        healPixIndices : `list` of `int`s
            HealPix indices representing the locations of each currently
            existing DiaObject.
        """
        hpIndex = toIndex(self.config.nside,
                          diaSrc["ra"],
                          diaSrc["decl"])
        healPixIndices.append(hpIndex)

        sphPoint = geom.SpherePoint(diaSrc["ra"],
                                    diaSrc["decl"],
                                    geom.degrees)
        # Start the per-object coordinate list used for later averaging.
        diaObjCoords.append([sphPoint])

        # Adding a record to the catalog draws the next id from its factory.
        diaObjId = idCat.addNew().get("id")
        diaObjCat.append(self.createDiaObject(diaObjId,
                                              diaSrc["ra"],
                                              diaSrc["decl"]))
        # NOTE(review): if ``diaSources`` has no integer ``diaObjectId``
        # column yet, pandas will presumably create it as float64, which
        # cannot hold every 64-bit id exactly - confirm that callers
        # pre-create the column.
        diaSources.loc[(ccdVisit, diaSourceId), "diaObjectId"] = diaObjId

    def updateCatalogs(self,
                       matchIndex,
                       diaSrc,
                       diaSources,
                       ccdVisit,
                       diaSourceId,
                       diaObjCat,
                       diaObjCoords,
                       healPixIndices):
        """Update DiaObject and DiaSource values after an association.

        Parameters
        ----------
        matchIndex : `int`
            Array index location of the DiaObject that ``diaSrc`` was
            associated to.
        diaSrc : `pandas.Series`
            Full DiaSource that was associated to the DiaObject.
        diaSources : `pandas.DataFrame`
            DiaSource catalog to update information in, indexed by
            ``(ccdVisitId, diaSourceId)``. The catalog is modified in place.
        ccdVisit : `int`
            Unique identifier of the ccdVisit where ``diaSrc`` was observed.
        diaSourceId : `int`
            Unique identifier of the DiaSource.
        diaObjCat : `list` of `dict`s
            Catalog of diaObjects containing the matched object to update.
        diaObjCoords : `list` of `list`s of `lsst.geom.SpherePoint`s
            Set of coordinates of DiaSource locations that make up the
            DiaObject average coordinate.
        healPixIndices : `list` of `int`s
            HealPix indices representing the locations of each currently
            existing DiaObject.
        """
        diaObject = diaObjCat[matchIndex]

        # Fold the new position into the DiaObject's average coordinate.
        sphPoint = geom.SpherePoint(diaSrc["ra"],
                                    diaSrc["decl"],
                                    geom.degrees)
        diaObjCoords[matchIndex].append(sphPoint)
        aveCoord = geom.averageSpherePoint(diaObjCoords[matchIndex])
        diaObject["ra"] = aveCoord.getRa().asDegrees()
        diaObject["decl"] = aveCoord.getDec().asDegrees()
        diaObject["nDiaSources"] += 1

        # The averaged position may fall in a different pixel, so the
        # index is recomputed from the updated coordinate.
        healPixIndices[matchIndex] = toIndex(self.config.nside,
                                             diaObject["ra"],
                                             diaObject["decl"])
        # Update DiaObject Id that this source is now associated to.
        diaSources.loc[(ccdVisit, diaSourceId), "diaObjectId"] = \
            diaObject["diaObjectId"]

    def findMatches(self, src_ra, src_dec, tol, hpIndices, diaObjs):
        """Search healPixels around DiaSource locations for DiaObjects.

        Parameters
        ----------
        src_ra : `float`
            DiaSource RA location.
        src_dec : `float`
            DiaSource Dec location.
        tol : `float`
            Search radius in arcseconds. It is converted to radians and
            used as the maximum radius of the disc of healPixels to search
            for DiaObjects.
        hpIndices : `list` of `int`s
            List of heal pix indices containing the DiaObjects in ``diaObjs``.
        diaObjs : `list` of `dict`s
            Catalog of diaObjects with full location information for
            comparing to DiaSources.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct containing

            ``dists``
                Array of distances between the current DiaSource and the
                candidate diaObjects, computed as the Euclidean distance
                between the ``eq2xyz`` vectors of the two positions.
                `None` if no candidate was found.
                (`numpy.ndarray` or `None`)
            ``matches``
                Array of array indices of diaObjects this DiaSource matches
                to, in the same order as ``dists``. `None` if no candidate
                was found.
                (`numpy.ndarray` or `None`)
        """
        match_indices = query_disc(self.config.nside,
                                   src_ra,
                                   src_dec,
                                   np.deg2rad(tol/3600.))
        # Positions in ``hpIndices`` (and hence in ``diaObjs``) whose pixel
        # lies in the searched disc.
        matchIndices = np.argwhere(np.isin(hpIndices, match_indices)).flatten()

        if len(matchIndices) < 1:
            return pipeBase.Struct(dists=None, matches=None)

        # The source vector is the same for every candidate.
        srcXyz = eq2xyz(src_ra, src_dec)
        dists = np.array(
            [np.sqrt(np.sum((srcXyz
                             - eq2xyz(diaObjs[match]["ra"],
                                      diaObjs[match]["decl"]))**2))
             for match in matchIndices])
        return pipeBase.Struct(
            dists=dists,
            matches=matchIndices)

    def createDiaObject(self, objId, ra, decl):
        """Create a simple empty DiaObject with location and id information.

        Parameters
        ----------
        objId : `int`
            Unique ID for this new DiaObject.
        ra : `float`
            RA location of this DiaObject.
        decl : `float`
            Dec location of this DiaObject.

        Returns
        -------
        DiaObject : `dict`
            Dictionary of values representing a DiaObject, with keys
            ``diaObjectId``, ``ra``, ``decl`` and ``nDiaSources`` (set to 1).
        """
        new_dia_object = {"diaObjectId": objId,
                          "ra": ra,
                          "decl": decl,
                          "nDiaSources": 1}
        return new_dia_object
# Cross-reference index (signatures of functions referenced in this listing):
#   SimpleAssociationTask.updateCatalogs(self, matchIndex, diaSrc, diaSources, ccdVisit, diaSourceId, diaObjCat, diaObjCoords, healPixIndices)
#   SimpleAssociationTask.addNewDiaObject(self, diaSrc, diaSources, ccdVisit, diaSourceId, diaObjCat, idCat, diaObjCoords, healPixIndices)
#   SimpleAssociationTask.findMatches(self, src_ra, src_dec, tol, hpIndices, diaObjs)
#   SimpleAssociationTask.run(self, diaSources, tractPatchId, skymapBits)
#   associationUtils.query_disc(nside, ra, dec, max_rad, min_rad=0)