Coverage for python/lsst/pipe/tasks/simpleAssociation.py: 23%

82 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-21 03:37 -0700

1# This file is part of pipe_tasks. 

2 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22 

23"""Simple association algorithm for DRP. 

24Adapted from http://github.com/LSSTDESC/dia_pipe 

25""" 

26 

27import numpy as np 

28import pandas as pd 

29 

30import lsst.afw.table as afwTable 

31import lsst.geom as geom 

32import lsst.pex.config as pexConfig 

33import lsst.pipe.base as pipeBase 

34from lsst.obs.base import ExposureIdInfo 

35 

36from .associationUtils import query_disc, eq2xyz, toIndex 

37 

38 

class SimpleAssociationConfig(pexConfig.Config):
    """Settings that control `SimpleAssociationTask`.

    ``tolerance`` bounds how far apart (in arcseconds) a DiaSource and a
    DiaObject may be and still be associated; ``nside`` selects the HEALPix
    resolution used to index DiaObject positions.
    """
    tolerance = pexConfig.Field(
        doc='maximum distance to match sources together in arcsec',
        dtype=float,
        default=0.5,
    )
    nside = pexConfig.Field(
        doc='Healpix nside value used for indexing',
        dtype=int,
        default=2**18,
    )

52 

53 

class SimpleAssociationTask(pipeBase.Task):
    """Construct DiaObjects from a DataFrame of DiaSources by spatially
    associating the sources.

    Represents a simple, brute force algorithm, 2-way matching of DiaSources
    into DiaObjects. For simplicity each DiaSource is only compared against
    the nearest DiaObject within the matching radius: if that DiaObject has
    already received a DiaSource from the same ccdVisit, a new DiaObject is
    created rather than trying the next nearest candidate.
    """
    ConfigClass = SimpleAssociationConfig
    _DefaultName = "simpleAssociation"

    def run(self, diaSources, tractPatchId, skymapBits):
        """Associate DiaSources into a collection of DiaObjects using a
        brute force matching algorithm.

        Reproducibility for the same input data is assured by processing the
        DiaSource data in ccdVisit order.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DiaSources grouped by CcdVisitId to spatially associate into
            DiaObjects. Must provide ``ccdVisitId``, ``diaSourceId``, ``ra``
            and ``decl``. The DataFrame is modified in place: it is
            re-indexed and its ``diaObjectId`` column is filled in.
        tractPatchId : `int`
            Unique identifier for the tract patch.
        skymapBits : `int`
            Maximum number of bits used by the ``tractPatchId`` integer
            identifier.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with attributes:

            ``assocDiaSources``
                Table of DiaSources with updated values for the DiaObjects
                they are spatially associated to, indexed by ``diaSourceId``
                (`pandas.DataFrame`).
            ``diaObjects``
                Table of DiaObjects from matching DiaSources, indexed by
                ``diaObjectId`` (`pandas.DataFrame`).
        """
        # Expected indexes include diaSourceId or meaningless range index
        # If meaningless range index, drop it, else keep it.
        doDropIndex = diaSources.index.names[0] is None
        diaSources.reset_index(inplace=True, drop=doDropIndex)

        # Index by (ccdVisitId, diaSourceId) so that the sources of one
        # ccdVisit can be selected together and each source can be addressed
        # individually when its diaObjectId is written back.
        # NOTE(review): no explicit sort is performed here; the order of
        # sources within a ccdVisit is presumably the input row order.
        diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)

        # Parallel lists, all indexed by DiaObject position: the object
        # records, the source coordinates contributing to each object, and
        # the HEALPix pixel currently holding each object.
        diaObjectCat = []
        diaObjectCoords = []
        healPixIndices = []

        # Create Id factory and catalog for creating DiaObjectIds.
        exposureIdInfo = ExposureIdInfo(tractPatchId, skymapBits)
        idFactory = exposureIdInfo.makeSourceIdFactory()
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema(),
                                      idFactory))

        # Matching tolerance converted from arcseconds to radians once,
        # instead of once per DiaSource.
        maxMatchDist = np.deg2rad(self.config.tolerance/3600)

        for ccdVisit in diaSources.index.levels[0]:
            ccdVisitSources = diaSources.loc[ccdVisit]
            # While no DiaObject exists yet (the first ccdVisit), every
            # DiaSource seeds its own DiaObject and no matching is attempted.
            # Evaluated once per ccdVisit so that sources of the first visit
            # are never matched against each other.
            isFirstVisit = not diaObjectCat
            # DiaObjects already assigned a DiaSource from this ccdVisit.
            usedMatchIndices = set()
            for diaSourceId, diaSrc in ccdVisitSources.iterrows():
                matchIndex = None
                if not isFirstVisit:
                    matchIndex = self._findUnusedMatch(diaSrc,
                                                       maxMatchDist,
                                                       healPixIndices,
                                                       diaObjectCat,
                                                       usedMatchIndices)
                if matchIndex is None:
                    # No candidate, nearest candidate outside the tolerance,
                    # or nearest candidate already used for this ccdVisit.
                    self.addNewDiaObject(diaSrc,
                                         diaSources,
                                         ccdVisit,
                                         diaSourceId,
                                         diaObjectCat,
                                         idCat,
                                         diaObjectCoords,
                                         healPixIndices)
                else:
                    self.updateCatalogs(matchIndex,
                                        diaSrc,
                                        diaSources,
                                        ccdVisit,
                                        diaSourceId,
                                        diaObjectCat,
                                        diaObjectCoords,
                                        healPixIndices)
                    usedMatchIndices.add(matchIndex)

        # Drop indices before returning associated diaSource catalog.
        diaSources.reset_index(inplace=True)
        diaSources.set_index("diaSourceId", inplace=True, verify_integrity=True)

        # An empty structured array keeps the expected columns present when
        # no DiaObject was created.
        objs = diaObjectCat if diaObjectCat else np.array([], dtype=[('diaObjectId', 'int64'),
                                                                     ('ra', 'float64'),
                                                                     ('decl', 'float64'),
                                                                     ('nDiaSources', 'int64')])
        diaObjects = pd.DataFrame(data=objs)

        if "diaObjectId" in diaObjects.columns:
            diaObjects.set_index("diaObjectId", inplace=True, verify_integrity=True)

        return pipeBase.Struct(
            assocDiaSources=diaSources,
            diaObjects=diaObjects)

    def _findUnusedMatch(self, diaSrc, maxMatchDist, healPixIndices, diaObjCat, usedMatchIndices):
        """Return the index of the DiaObject ``diaSrc`` should join, if any.

        Parameters
        ----------
        diaSrc : `pandas.Series`
            DiaSource to associate; its ``ra`` and ``decl`` are read.
        maxMatchDist : `float`
            Largest accepted distance, in the units of the distances
            returned by `findMatches` (compared against radians by `run`).
        healPixIndices : `list` of `int`s
            HealPix indices representing the locations of each currently
            existing DiaObject.
        diaObjCat : `list` of `dict`s
            Catalog of the currently existing DiaObjects.
        usedMatchIndices : `set` of `int`s
            Indices of DiaObjects already matched within the current
            ccdVisit.

        Returns
        -------
        matchIndex : `int` or `None`
            Index into ``diaObjCat`` of the nearest candidate, or `None` when
            there is no candidate, the nearest candidate is not closer than
            ``maxMatchDist``, or it has already been used.
        """
        # Candidates are gathered within twice the tolerance; the stricter
        # distance cut is applied below.
        matchResult = self.findMatches(diaSrc["ra"],
                                       diaSrc["decl"],
                                       2*self.config.tolerance,
                                       healPixIndices,
                                       diaObjCat)
        dists = matchResult.dists
        if dists is None:
            return None
        nearest = np.argmin(dists)
        # Written as "not <" so that a NaN distance is rejected.
        if not dists[nearest] < maxMatchDist:
            return None
        matchIndex = matchResult.matches[nearest]
        # Only the nearest DiaObject is considered; if it is taken, the
        # caller creates a new DiaObject.
        if matchIndex in usedMatchIndices:
            return None
        return matchIndex

    def addNewDiaObject(self,
                        diaSrc,
                        diaSources,
                        ccdVisit,
                        diaSourceId,
                        diaObjCat,
                        idCat,
                        diaObjCoords,
                        healPixIndices):
        """Create a new DiaObject and append its data.

        Parameters
        ----------
        diaSrc : `pandas.Series`
            Full unassociated DiaSource to create a DiaObject from.
        diaSources : `pandas.DataFrame`
            DiaSource catalog to update information in. The catalog is
            modified in place.
        ccdVisit : `int`
            Unique identifier of the ccdVisit where ``diaSrc`` was observed.
        diaSourceId : `int`
            Unique identifier of the DiaSource.
        diaObjCat : `list` of `dict`s
            Catalog of diaObjects to append the new object to.
        idCat : `lsst.afw.table.SourceCatalog`
            Catalog with the IdFactory used to generate unique DiaObject
            identifiers.
        diaObjCoords : `list` of `list`s of `lsst.geom.SpherePoint`s
            Set of coordinates of DiaSource locations that make up the
            DiaObject average coordinate.
        healPixIndices : `list` of `int`s
            HealPix indices representing the locations of each currently
            existing DiaObject.
        """
        # The three lists are appended to together so that one index
        # addresses the same DiaObject in each of them.
        hpIndex = toIndex(self.config.nside,
                          diaSrc["ra"],
                          diaSrc["decl"])
        healPixIndices.append(hpIndex)

        sphPoint = geom.SpherePoint(diaSrc["ra"],
                                    diaSrc["decl"],
                                    geom.degrees)
        diaObjCoords.append([sphPoint])

        # Adding a record to idCat is what draws the next unique id.
        diaObjId = idCat.addNew().get("id")
        diaObjCat.append(self.createDiaObject(diaObjId,
                                              diaSrc["ra"],
                                              diaSrc["decl"]))
        diaSources.loc[(ccdVisit, diaSourceId), "diaObjectId"] = diaObjId

    def updateCatalogs(self,
                       matchIndex,
                       diaSrc,
                       diaSources,
                       ccdVisit,
                       diaSourceId,
                       diaObjCat,
                       diaObjCoords,
                       healPixIndices):
        """Update DiaObject and DiaSource values after an association.

        Parameters
        ----------
        matchIndex : `int`
            Array index location of the DiaObject that ``diaSrc`` was
            associated to.
        diaSrc : `pandas.Series`
            DiaSource that was associated to the DiaObject at
            ``matchIndex``.
        diaSources : `pandas.DataFrame`
            DiaSource catalog to update information in. The catalog is
            modified in place.
        ccdVisit : `int`
            Unique identifier of the ccdVisit where ``diaSrc`` was observed.
        diaSourceId : `int`
            Unique identifier of the DiaSource.
        diaObjCat : `list` of `dict`s
            Catalog of diaObjects; the matched entry is updated in place.
        diaObjCoords : `list` of `list`s of `lsst.geom.SpherePoint`s
            Set of coordinates of DiaSource locations that make up the
            DiaObject average coordinate.
        healPixIndices : `list` of `int`s
            HealPix indices representing the locations of each currently
            existing DiaObject.
        """
        # Update location and healPix index.
        sphPoint = geom.SpherePoint(diaSrc["ra"],
                                    diaSrc["decl"],
                                    geom.degrees)
        diaObjCoords[matchIndex].append(sphPoint)
        aveCoord = geom.averageSpherePoint(diaObjCoords[matchIndex])

        diaObject = diaObjCat[matchIndex]
        diaObject["ra"] = aveCoord.getRa().asDegrees()
        diaObject["decl"] = aveCoord.getDec().asDegrees()
        diaObject["nDiaSources"] += 1
        # The averaged position may fall in a different pixel than before.
        healPixIndices[matchIndex] = toIndex(self.config.nside,
                                             diaObject["ra"],
                                             diaObject["decl"])
        # Update DiaObject Id that this source is now associated to.
        diaSources.loc[(ccdVisit, diaSourceId), "diaObjectId"] = diaObject["diaObjectId"]

    def findMatches(self, src_ra, src_dec, tol, hpIndices, diaObjs):
        """Search healPixels around DiaSource locations for DiaObjects.

        Parameters
        ----------
        src_ra : `float`
            DiaSource RA location.
        src_dec : `float`
            DiaSource Dec location.
        tol : `float`
            Search size in arcseconds; it is converted to radians and handed
            to ``query_disc`` to obtain the healPixels to search for
            DiaObjects.
        hpIndices : `list` of `int`s
            List of heal pix indices containing the DiaObjects in ``diaObjs``.
        diaObjs : `list` of `dict`s
            Catalog of diaObjects with full location information for
            comparing to DiaSources.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct containing

            ``dists``
                Array of distances between the current DiaSource and the
                candidate diaObjects, computed as the Euclidean distance
                between their ``eq2xyz`` vectors
                (`numpy.ndarray` or `None`).
            ``matches``
                Array of array indices of diaObjects this DiaSource matches
                to (`numpy.ndarray` or `None`).
        """
        discPixels = query_disc(self.config.nside,
                                src_ra,
                                src_dec,
                                np.deg2rad(tol/3600.))
        # Positions (into hpIndices and diaObjs) of the DiaObjects whose
        # pixel is among the returned pixels.
        candidateIndices = np.argwhere(np.isin(hpIndices, discPixels)).flatten()

        if len(candidateIndices) < 1:
            return pipeBase.Struct(dists=None, matches=None)

        # The source vector does not depend on the candidate, so compute it
        # once rather than once per candidate.
        srcXyz = eq2xyz(src_ra, src_dec)
        dists = np.array(
            [np.sqrt(np.sum((srcXyz
                             - eq2xyz(diaObjs[candidate]["ra"],
                                      diaObjs[candidate]["decl"]))**2))
             for candidate in candidateIndices])
        return pipeBase.Struct(
            dists=dists,
            matches=candidateIndices)

    def createDiaObject(self, objId, ra, decl):
        """Create a simple empty DiaObject with location and id information.

        Parameters
        ----------
        objId : `int`
            Unique ID for this new DiaObject.
        ra : `float`
            RA location of this DiaObject.
        decl : `float`
            Dec location of this DiaObject.

        Returns
        -------
        DiaObject : `dict`
            Dictionary of values representing a DiaObject, with keys
            ``diaObjectId``, ``ra``, ``decl`` and ``nDiaSources`` (set to 1).
        """
        return {"diaObjectId": objId,
                "ra": ra,
                "decl": decl,
                "nDiaSources": 1}