Coverage for python/lsst/pipe/tasks/simpleAssociation.py: 20%

85 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 11:06 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Simple association algorithm for DRP. 

23Adapted from http://github.com/LSSTDESC/dia_pipe 

24""" 

25__all__ = ["SimpleAssociationConfig", "SimpleAssociationTask"] 

26 

27import numpy as np 

28import pandas as pd 

29 

30import lsst.afw.table as afwTable 

31import lsst.geom as geom 

32import lsst.pex.config as pexConfig 

33import lsst.pipe.base as pipeBase 

34from lsst.meas.base import IdGenerator 

35 

36from .associationUtils import query_disc, eq2xyz, toIndex 

37 

38 

39class SimpleAssociationConfig(pexConfig.Config): 

40 """Configuration parameters for the SimpleAssociationTask 

41 """ 

42 tolerance = pexConfig.Field( 

43 dtype=float, 

44 doc='maximum distance to match sources together in arcsec', 

45 default=0.5 

46 ) 

47 nside = pexConfig.Field( 

48 dtype=int, 

49 doc='Healpix nside value used for indexing', 

50 default=2**18, 

51 ) 

52 

53 

54class SimpleAssociationTask(pipeBase.Task): 

55 """Construct DiaObjects from a DataFrame of DIASources by spatially 

56 associating the sources. 

57 

58 Represents a simple, brute force algorithm, 2-way matching of DiaSources 

59 into. DiaObjects. Algorithm picks the nearest, first match within the 

60 matching radius of a DiaObject to associate a source to for simplicity. 

61 """ 

62 ConfigClass = SimpleAssociationConfig 

63 _DefaultName = "simpleAssociation" 

64 

65 def run(self, diaSources, idGenerator=None): 

66 """Associate DiaSources into a collection of DiaObjects using a 

67 brute force matching algorithm. 

68 

69 Reproducible for the same input data is assured by ordering the 

70 DiaSource data by visit,detector. 

71 

72 Parameters 

73 ---------- 

74 diaSources : `pandas.DataFrame` 

75 DiaSources in clusters of visit, detector to spatially associate 

76 into DiaObjects. 

77 idGenerator : `lsst.meas.base.IdGenerator`, optional 

78 Object that generates Object IDs and random number generator seeds. 

79 

80 Returns 

81 ------- 

82 results : `lsst.pipe.base.Struct` 

83 Results struct with attributes: 

84 

85 ``assocDiaSources`` 

86 Table of DiaSources with updated values for the DiaObjects 

87 they are spatially associated to (`pandas.DataFrame`). 

88 ``diaObjects`` 

89 Table of DiaObjects from matching DiaSources 

90 (`pandas.DataFrame`). 

91 """ 

92 # Expected indexes include diaSourceId or meaningless range index 

93 # If meaningless range index, drop it, else keep it. 

94 doDropIndex = diaSources.index.names[0] is None 

95 diaSources.reset_index(inplace=True, drop=doDropIndex) 

96 

97 # Sort by visit, detector, and diaSourceId to get a reproducible 

98 # ordering for the association. Use a temporary combined visit,detector 

99 # to simplify multi-index operations below; we delete it at the end. 

100 diaSources[("visit,detector")] = list(zip(diaSources["visit"], diaSources["detector"])) 

101 diaSources.set_index([("visit,detector"), "diaSourceId"], inplace=True) 

102 

103 # Empty lists to store matching and location data. 

104 diaObjectCat = [] 

105 diaObjectCoords = [] 

106 healPixIndices = [] 

107 

108 # Create Id factory and catalog for creating DiaObjectIds. 

109 if idGenerator is None: 

110 idGenerator = IdGenerator() 

111 idCat = idGenerator.make_source_catalog(afwTable.SourceTable.makeMinimalSchema()) 

112 

113 for visit, detector in diaSources.index.levels[0]: 

114 # For the first visit,detector, just copy the DiaSource info into the 

115 # diaObject data to create the first set of Objects. 

116 orderedSources = diaSources.loc[(visit, detector)] 

117 if len(diaObjectCat) == 0: 

118 for diaSourceId, diaSrc in orderedSources.iterrows(): 

119 self.addNewDiaObject(diaSrc, 

120 diaSources, 

121 visit, 

122 detector, 

123 diaSourceId, 

124 diaObjectCat, 

125 idCat, 

126 diaObjectCoords, 

127 healPixIndices) 

128 continue 

129 # Temp list to store DiaObjects already used for this visit, detector. 

130 usedMatchIndicies = [] 

131 # Run over subsequent data. 

132 for diaSourceId, diaSrc in orderedSources.iterrows(): 

133 # Find matches. 

134 matchResult = self.findMatches(diaSrc["ra"], 

135 diaSrc["dec"], 

136 2*self.config.tolerance, 

137 healPixIndices, 

138 diaObjectCat) 

139 dists = matchResult.dists 

140 matches = matchResult.matches 

141 # Create a new DiaObject if no match found. 

142 if dists is None: 

143 self.addNewDiaObject(diaSrc, 

144 diaSources, 

145 visit, 

146 detector, 

147 diaSourceId, 

148 diaObjectCat, 

149 idCat, 

150 diaObjectCoords, 

151 healPixIndices) 

152 continue 

153 # If matched, update catalogs and arrays. 

154 if np.min(dists) < np.deg2rad(self.config.tolerance/3600): 

155 matchDistArg = np.argmin(dists) 

156 matchIndex = matches[matchDistArg] 

157 # Test to see if the DiaObject has been used. 

158 if np.isin([matchIndex], usedMatchIndicies).sum() < 1: 

159 self.updateCatalogs(matchIndex, 

160 diaSrc, 

161 diaSources, 

162 visit, 

163 detector, 

164 diaSourceId, 

165 diaObjectCat, 

166 diaObjectCoords, 

167 healPixIndices) 

168 usedMatchIndicies.append(matchIndex) 

169 # If the matched DiaObject has already been used, create a 

170 # new DiaObject for this DiaSource. 

171 else: 

172 self.addNewDiaObject(diaSrc, 

173 diaSources, 

174 visit, 

175 detector, 

176 diaSourceId, 

177 diaObjectCat, 

178 idCat, 

179 diaObjectCoords, 

180 healPixIndices) 

181 # Create new DiaObject if no match found within the matching 

182 # tolerance. 

183 else: 

184 self.addNewDiaObject(diaSrc, 

185 diaSources, 

186 visit, 

187 detector, 

188 diaSourceId, 

189 diaObjectCat, 

190 idCat, 

191 diaObjectCoords, 

192 healPixIndices) 

193 

194 # Drop indices before returning associated diaSource catalog. 

195 diaSources.reset_index(inplace=True) 

196 del diaSources["visit,detector"] 

197 diaSources.set_index("diaSourceId", inplace=True, verify_integrity=True) 

198 

199 objs = diaObjectCat if diaObjectCat else np.array([], dtype=[('diaObjectId', 'int64'), 

200 ('ra', 'float64'), 

201 ('dec', 'float64'), 

202 ('nDiaSources', 'int64')]) 

203 diaObjects = pd.DataFrame(data=objs) 

204 

205 if "diaObjectId" in diaObjects.columns: 

206 diaObjects.set_index("diaObjectId", inplace=True, verify_integrity=True) 

207 

208 return pipeBase.Struct( 

209 assocDiaSources=diaSources, 

210 diaObjects=diaObjects) 

211 

212 def addNewDiaObject(self, 

213 diaSrc, 

214 diaSources, 

215 visit, 

216 detector, 

217 diaSourceId, 

218 diaObjCat, 

219 idCat, 

220 diaObjCoords, 

221 healPixIndices): 

222 """Create a new DiaObject and append its data. 

223 

224 Parameters 

225 ---------- 

226 diaSrc : `pandas.Series` 

227 Full unassociated DiaSource to create a DiaObject from. 

228 diaSources : `pandas.DataFrame` 

229 DiaSource catalog to update information in. The catalog is 

230 modified in place. Must be indexed on: 

231 `(visit, detector), diaSourceId`. 

232 visit, detector : `int` 

233 Visit and detector ids where ``diaSrc`` was observed. 

234 diaSourceId : `int` 

235 Unique identifier of the DiaSource. 

236 diaObjectCat : `list` of `dict`s 

237 Catalog of diaObjects to append the new object o. 

238 idCat : `lsst.afw.table.SourceCatalog` 

239 Catalog with the IdFactory used to generate unique DiaObject 

240 identifiers. 

241 diaObjectCoords : `list` of `list`s of `lsst.geom.SpherePoint`s 

242 Set of coordinates of DiaSource locations that make up the 

243 DiaObject average coordinate. 

244 healPixIndices : `list` of `int`s 

245 HealPix indices representing the locations of each currently 

246 existing DiaObject. 

247 """ 

248 hpIndex = toIndex(self.config.nside, 

249 diaSrc["ra"], 

250 diaSrc["dec"]) 

251 healPixIndices.append(hpIndex) 

252 

253 sphPoint = geom.SpherePoint(diaSrc["ra"], 

254 diaSrc["dec"], 

255 geom.degrees) 

256 diaObjCoords.append([sphPoint]) 

257 

258 diaObjId = idCat.addNew().get("id") 

259 diaObjCat.append(self.createDiaObject(diaObjId, 

260 diaSrc["ra"], 

261 diaSrc["dec"])) 

262 diaSources.loc[((visit, detector), diaSourceId), "diaObjectId"] = diaObjId 

263 

264 def updateCatalogs(self, 

265 matchIndex, 

266 diaSrc, 

267 diaSources, 

268 visit, 

269 detector, 

270 diaSourceId, 

271 diaObjCat, 

272 diaObjCoords, 

273 healPixIndices): 

274 """Update DiaObject and DiaSource values after an association. 

275 

276 Parameters 

277 ---------- 

278 matchIndex : `int` 

279 Array index location of the DiaObject that ``diaSrc`` was 

280 associated to. 

281 diaSrc : `pandas.Series` 

282 Full unassociated DiaSource to create a DiaObject from. 

283 diaSources : `pandas.DataFrame` 

284 DiaSource catalog to update information in. The catalog is 

285 modified in place. Must be indexed on: 

286 `(visit, detector), diaSourceId`. 

287 visit, detector : `int` 

288 Visit and detector ids where ``diaSrc`` was observed. 

289 diaSourceId : `int` 

290 Unique identifier of the DiaSource. 

291 diaObjectCat : `list` of `dict`s 

292 Catalog of diaObjects to append the new object o. 

293 diaObjectCoords : `list` of `list`s of `lsst.geom.SpherePoint`s 

294 Set of coordinates of DiaSource locations that make up the 

295 DiaObject average coordinate. 

296 healPixIndices : `list` of `int`s 

297 HealPix indices representing the locations of each currently 

298 existing DiaObject. 

299 """ 

300 # Update location and healPix index. 

301 sphPoint = geom.SpherePoint(diaSrc["ra"], 

302 diaSrc["dec"], 

303 geom.degrees) 

304 diaObjCoords[matchIndex].append(sphPoint) 

305 aveCoord = geom.averageSpherePoint(diaObjCoords[matchIndex]) 

306 diaObjCat[matchIndex]["ra"] = aveCoord.getRa().asDegrees() 

307 diaObjCat[matchIndex]["dec"] = aveCoord.getDec().asDegrees() 

308 nSources = diaObjCat[matchIndex]["nDiaSources"] 

309 diaObjCat[matchIndex]["nDiaSources"] = nSources + 1 

310 healPixIndices[matchIndex] = toIndex(self.config.nside, 

311 diaObjCat[matchIndex]["ra"], 

312 diaObjCat[matchIndex]["dec"]) 

313 # Update DiaObject Id that this source is now associated to. 

314 diaSources.loc[((visit, detector), diaSourceId), "diaObjectId"] = \ 

315 diaObjCat[matchIndex]["diaObjectId"] 

316 

317 def findMatches(self, src_ra, src_dec, tol, hpIndices, diaObjs): 

318 """Search healPixels around DiaSource locations for DiaObjects. 

319 

320 Parameters 

321 ---------- 

322 src_ra : `float` 

323 DiaSource RA location. 

324 src_dec : `float` 

325 DiaSource Dec location. 

326 tol : `float` 

327 Size of annulus to convert to covering healPixels and search for 

328 DiaObjects. 

329 hpIndices : `list` of `int`s 

330 List of heal pix indices containing the DiaObjects in ``diaObjs``. 

331 diaObjs : `list` of `dict`s 

332 Catalog diaObjects to with full location information for comparing 

333 to DiaSources. 

334 

335 Returns 

336 ------- 

337 results : `lsst.pipe.base.Struct` 

338 Results struct containing 

339 

340 ``dists`` 

341 Array of distances between the current DiaSource diaObjects. 

342 (`numpy.ndarray` or `None`) 

343 ``matches`` 

344 Array of array indices of diaObjects this DiaSource matches to. 

345 (`numpy.ndarray` or `None`) 

346 """ 

347 match_indices = query_disc(self.config.nside, 

348 src_ra, 

349 src_dec, 

350 np.deg2rad(tol/3600.)) 

351 matchIndices = np.argwhere(np.isin(hpIndices, match_indices)).flatten() 

352 

353 if len(matchIndices) < 1: 

354 return pipeBase.Struct(dists=None, matches=None) 

355 

356 dists = np.array( 

357 [np.sqrt(np.sum((eq2xyz(src_ra, src_dec) 

358 - eq2xyz(diaObjs[match]["ra"], 

359 diaObjs[match]["dec"]))**2)) 

360 for match in matchIndices]) 

361 return pipeBase.Struct( 

362 dists=dists, 

363 matches=matchIndices) 

364 

365 def createDiaObject(self, objId, ra, dec): 

366 """Create a simple empty DiaObject with location and id information. 

367 

368 Parameters 

369 ---------- 

370 objId : `int` 

371 Unique ID for this new DiaObject. 

372 ra : `float` 

373 RA location of this DiaObject. 

374 dec : `float` 

375 Dec location of this DiaObject 

376 

377 Returns 

378 ------- 

379 DiaObject : `dict` 

380 Dictionary of values representing a DiaObject. 

381 """ 

382 new_dia_object = {"diaObjectId": objId, 

383 "ra": ra, 

384 "dec": dec, 

385 "nDiaSources": 1} 

386 return new_dia_object