Coverage for python/lsst/afw/table/multiMatch.py: 11%

131 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-24 02:28 -0700

1# This file is part of afw. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21import collections.abc 

22 

23import numpy 

24 

25import lsst.geom 

26from ._schemaMapper import SchemaMapper 

27from ._table import CoordKey, SourceRecord 

28 

29 

class MultiMatch:
    """Initialize a multi-catalog match.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`
        Schema shared by all catalogs to be included in the match.
    dataIdFormat : `dict`
        Set of name: type for all data ID keys (e.g. {"visit":int,
        "ccd":int}).
    coordField : `str`, optional
        Prefix for _ra and _dec fields that contain the
        coordinates to use for the match.
    idField : `str`, optional
        Name of the field in schema that contains unique object
        IDs.
    radius : `lsst.geom.Angle`, optional
        Maximum separation for a match. Defaults to 0.5 arcseconds.
    RecordClass : `lsst.afw.table.BaseRecord`
        Type of record to expect in catalogs to be matched.
    """

    def __init__(self, schema, dataIdFormat, coordField="coord", idField="id", radius=None,
                 RecordClass=SourceRecord):
        if radius is None:
            radius = 0.5*lsst.geom.arcseconds
        elif not isinstance(radius, lsst.geom.Angle):
            raise ValueError("'radius' argument must be an Angle")
        self.radius = radius
        self.mapper = SchemaMapper(schema)
        self.mapper.addMinimalSchema(schema, True)
        self.coordKey = CoordKey(schema[coordField])
        self.idKey = schema.find(idField).key
        self.dataIdKeys = {}
        outSchema = self.mapper.editOutputSchema()
        outSchema.setAliasMap(self.mapper.getInputSchema().getAliasMap())
        self.objectKey = outSchema.addField(
            "object", type=numpy.int64, doc="Unique ID for joined sources")
        for name, dataType in dataIdFormat.items():
            # Interpolate the component name into the doc string; the
            # original left the '%s' placeholder unformatted, so every
            # data ID field was documented as "'%s' data ID component".
            self.dataIdKeys[name] = outSchema.addField(
                name, type=dataType, doc="'%s' data ID component" % name)
        # self.result will be a catalog containing the union of all matched records, with an 'object' ID
        # column that can be used to group matches.  Sources that have ambiguous matches may appear
        # multiple times.
        self.result = None
        # self.reference will be a subset of self.result, with exactly one record for each group of matches
        # (we'll use the one from the first catalog matched into this group)
        # We'll use this to match against each subsequent catalog.
        self.reference = None
        # A set of ambiguous objects that we may want to ultimately remove from
        # the final merged catalog.
        self.ambiguous = set()
        # Table used to allocate new records for the output catalog.
        self.table = RecordClass.Table.make(self.mapper.getOutputSchema())
        # Counter used to assign the next object ID
        self.nextObjId = 1

    def makeRecord(self, inputRecord, dataId, objId):
        """Create a new result record from the given input record, using the
        given data ID and object ID to fill in additional columns.

        Parameters
        ----------
        inputRecord : `lsst.afw.table.source.sourceRecord`
            Record to use as the reference for the new result.
        dataId : `DataId` or `dict`
            Data id describing the data.
        objId : `int`
            Object id of the object to be added.

        Returns
        -------
        outputRecord : `lsst.afw.table.source.sourceRecord`
            Newly generated record.
        """
        outputRecord = self.table.copyRecord(inputRecord, self.mapper)
        for name, key in self.dataIdKeys.items():
            outputRecord.set(key, dataId[name])
        outputRecord.set(self.objectKey, objId)
        return outputRecord

    def add(self, catalog, dataId):
        """Add a new catalog to the match, corresponding to the given data ID.
        The new catalog is appended to the `self.result` and
        `self.reference` catalogs.

        Parameters
        ----------
        catalog : `lsst.afw.table.base.Catalog`
            Catalog to be added to the match result.
        dataId : `DataId` or `dict`
            Data id for the catalog to be added.
        """
        if self.result is None:
            # First catalog: every record starts a new object.
            self.result = self.table.Catalog(self.table)
            for record in catalog:
                self.result.append(self.makeRecord(
                    record, dataId, objId=self.nextObjId))
                self.nextObjId += 1
            self.reference = self.result.copy(deep=False)
            return
        catalog.sort(self.idKey)  # pre-sort for speedy by-id access later.
        # Will remove from this set as objects are matched.
        unmatchedIds = {record.get(self.idKey) for record in catalog}
        # Temporary dict mapping new source ID to a set of associated objects.
        newToObj = {}
        # NOTE(review): this module only imports lsst.geom at the top;
        # lsst.afw.table is resolvable here only because this module lives
        # inside the lsst.afw.table package, which is imported by call time.
        matches = lsst.afw.table.matchRaDec(self.reference, catalog, self.radius)
        matchedRefIds = set()
        matchedCatIds = set()
        for refRecord, newRecord, distance in matches:
            objId = refRecord.get(self.objectKey)
            if objId in matchedRefIds:
                # We've already matched this object against another new source,
                # mark it as ambiguous.
                self.ambiguous.add(objId)
            matchedRefIds.add(objId)
            if newRecord.get(self.idKey) in matchedCatIds:
                # We've already matched this new source to one or more other objects
                # Mark all involved objects as ambiguous
                self.ambiguous.add(objId)
                self.ambiguous |= newToObj.get(newRecord.get(self.idKey), set())
            matchedCatIds.add(newRecord.get(self.idKey))
            unmatchedIds.discard(newRecord.get(self.idKey))
            # Populate the newToObj dict (setdefault trick is an idiom for
            # appending to a dict-of-sets)
            newToObj.setdefault(newRecord.get(self.idKey), set()).add(objId)
            # Add a new result record for this match.
            self.result.append(self.makeRecord(newRecord, dataId, objId))
        # Add any unmatched sources from the new catalog as new objects to both
        # the joined result catalog and the reference catalog.
        for objId in unmatchedIds:
            newRecord = catalog.find(objId, self.idKey)
            resultRecord = self.makeRecord(newRecord, dataId, self.nextObjId)
            self.nextObjId += 1
            self.result.append(resultRecord)
            self.reference.append(resultRecord)

    def finish(self, removeAmbiguous=True):
        """Return the final match catalog, after sorting it by object, copying
        it to ensure contiguousness, and optionally removing ambiguous
        matches.

        After calling finish(), the in-progress state of the matcher
        is returned to the state it was just after construction, with
        the exception of the object ID counter (which is not reset).

        Parameters
        ----------
        removeAmbiguous : `bool`, optional
            Should ambiguous matches be removed from the match
            catalog?  Defaults to True.

        Returns
        -------
        result : `lsst.afw.table.base.Catalog`
            Final match catalog, sorted by object.
        """
        if removeAmbiguous:
            result = self.table.Catalog(self.table)
            for record in self.result:
                if record.get(self.objectKey) not in self.ambiguous:
                    result.append(record)
        else:
            result = self.result
        result.sort(self.objectKey)
        # Deep copy guarantees the returned catalog is contiguous.
        result = result.copy(deep=True)
        # Reset in-progress state (object ID counter deliberately kept).
        self.result = None
        self.reference = None
        self.ambiguous = set()
        return result

200 

201 

class GroupView(collections.abc.Mapping):
    """A mapping (i.e. dict-like object) that provides convenient
    operations on the concatenated catalogs returned by a MultiMatch
    object.

    A GroupView provides access to a catalog of grouped objects, in
    which the grouping is indicated by a field for which all records
    in a group have the same value.  Once constructed, it allows
    operations similar to those supported by SQL "GROUP BY", such as
    filtering and aggregate calculation.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`
        Catalog schema to use for the grouped object catalog.
    ids : `List`
        List of identifying keys for the groups in the catalog.
    groups : `List`
        List of catalog subsets associated with each key in ids.
    """

    @classmethod
    def build(cls, catalog, groupField="object"):
        """Construct a GroupView from a concatenated catalog.

        Parameters
        ----------
        catalog : `lsst.afw.table.base.Catalog`
            Input catalog, containing records grouped by a field in
            which all records in the same group have the same value.
            Must be sorted by the group field.
        groupField : `str`, optional
            Name or Key for the field that indicates groups.  Defaults
            to "object".

        Returns
        -------
        groupCatalog : `lsst.afw.table.multiMatch.GroupView`
            Constructed GroupView from the input concatenated catalog.
        """
        groupKey = catalog.schema.find(groupField).key
        # Because the catalog is sorted, each unique value's first index
        # marks the start of its contiguous group.
        ids, indices = numpy.unique(catalog.get(groupKey), return_index=True)
        groups = numpy.zeros(len(ids), dtype=object)
        ends = list(indices[1:]) + [len(catalog)]
        for n, (i1, i2) in enumerate(zip(indices, ends)):
            groups[n] = catalog[i1:i2]
            assert (groups[n].get(groupKey) == ids[n]).all()
        return cls(catalog.schema, ids, groups)

    def __init__(self, schema, ids, groups):
        self.schema = schema
        self.ids = ids
        self.groups = groups
        # Total number of records across all groups.
        self.count = sum(len(cat) for cat in self.groups)

    def __len__(self):
        return len(self.ids)

    def __iter__(self):
        return iter(self.ids)

    def __getitem__(self, key):
        index = numpy.searchsorted(self.ids, key)
        # searchsorted returns len(self.ids) for a key that sorts past the
        # last ID; guard so such keys raise KeyError, not IndexError.
        if index >= len(self.ids) or self.ids[index] != key:
            raise KeyError("Group with ID {0} not found".format(key))
        return self.groups[index]

    def where(self, predicate):
        """Return a new GroupView that contains only groups for which the
        given predicate function returns True.

        The predicate function is called once for each group, and
        passed a single argument: the subset catalog for that group.

        Parameters
        ----------
        predicate :
            Function to identify which groups should be included in
            the output.

        Returns
        -------
        outGroupView : `lsst.afw.table.multiMatch.GroupView`
            Subset GroupView containing only groups that match the
            predicate.
        """
        mask = numpy.zeros(len(self), dtype=bool)
        for i in range(len(self)):
            mask[i] = predicate(self.groups[i])
        return type(self)(self.schema, self.ids[mask], self.groups[mask])

    def aggregate(self, function, field=None, dtype=float):
        """Run an aggregate function on each group, returning an array with
        one element for each group.

        Parameters
        ----------
        function :
            Callable object that computes the aggregate value.  If
            `field` is None, called with the entire subset catalog as an
            argument.  If `field` is not None, called with an array view
            into that field.
        field : `str`, optional
            A string name or Key object that indicates a single field the aggregate
            is computed over.
        dtype :
            Data type of the output array.

        Returns
        -------
        result : Array of `dtype`
            Aggregated values for each group.
        """
        result = numpy.zeros(len(self), dtype=dtype)
        if field is not None:
            key = self.schema.find(field).key

            def f(cat):
                return function(cat.get(key))
        else:
            f = function
        for i in range(len(self)):
            result[i] = f(self.groups[i])
        return result

    def apply(self, function, field=None, dtype=float):
        """Run a non-aggregate function on each group, returning an array with
        one element for each record.

        Parameters
        ----------
        function :
            Callable object that computes the aggregate value.  If field is None,
            called with the entire subset catalog as an argument.  If field is not
            None, called with an array view into that field.
        field : `str`
            A string name or Key object that indicates a single field the aggregate
            is computed over.
        dtype :
            Data type for the output array.

        Returns
        -------
        result : `numpy.array` of `dtype`
            Result of the function calculated on an element-by-element basis.
        """
        result = numpy.zeros(self.count, dtype=dtype)
        if field is not None:
            key = self.schema.find(field).key

            def f(cat):
                return function(cat.get(key))
        else:
            f = function
        # Fill the flat output array group by group; 'stop' avoids
        # shadowing the builtin 'next'.
        last = 0
        for i in range(len(self)):
            stop = last + len(self.groups[i])
            result[last:stop] = f(self.groups[i])
            last = stop
        return result