Coverage for python/lsst/pipe/tasks/mergeDetections.py: 22% (147 statements)


# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["MergeDetectionsConfig", "MergeDetectionsTask"]

import numpy as np
from numpy.lib.recfunctions import rec_join

from .multiBandUtils import CullPeaksConfig

import lsst.afw.detection as afwDetect
import lsst.afw.image as afwImage
import lsst.afw.table as afwTable

from lsst.meas.algorithms import SkyObjectsTask
from lsst.skymap import BaseSkyMap
from lsst.pex.config import Config, Field, ListField, ConfigurableField, ConfigField
from lsst.pipe.base import (PipelineTask, PipelineTaskConfig, Struct,
                            PipelineTaskConnections)
import lsst.pipe.base.connectionTypes as cT
from lsst.meas.base import SkyMapIdGeneratorConfig


def matchCatalogsExact(catalog1, catalog2, patch1=None, patch2=None):
43 """Match two catalogs derived from the same mergeDet catalog. 

44 

45 When testing downstream features, like deblending methods/parameters 

46 and measurement algorithms/parameters, it is useful to to compare 

47 the same sources in two catalogs. In most cases this must be done 

48 by matching on either RA/DEC or XY positions, which occassionally 

49 will mismatch one source with another. 

50 

51 For a more robust solution, as long as the downstream catalog is 

52 derived from the same mergeDet catalog, exact source matching 

53 can be done via the unique ``(parent, deblend_peakID)`` 

54 combination. So this function performs this exact matching for 

55 all sources both catalogs. 

56 

57 Parameters 

58 ---------- 

59 catalog1, catalog2 : `lsst.afw.table.SourceCatalog` 

60 The two catalogs to merge 

61 patch1, patch2 : `array` of `int` 

62 Patch for each row, converted into an integer. 

63 

64 Returns 

65 ------- 

66 result : `list` of `lsst.afw.table.SourceMatch` 

67 List of matches for each source (using an inner join). 

68 """ 

    # Only match the individual sources; the parents will
    # already be matched by the mergeDet catalog.
    sidx1 = catalog1["parent"] != 0
    sidx2 = catalog2["parent"] != 0

    # Create the keys used to merge the catalogs.
    parents1 = np.array(catalog1["parent"][sidx1])
    peaks1 = np.array(catalog1["deblend_peakId"][sidx1])
    index1 = np.arange(len(catalog1))[sidx1]
    parents2 = np.array(catalog2["parent"][sidx2])
    peaks2 = np.array(catalog2["deblend_peakId"][sidx2])
    index2 = np.arange(len(catalog2))[sidx2]

    if patch1 is not None or patch2 is not None:
        if patch1 is None or patch2 is None:
            msg = ("If the catalogs are from different patches then patch1 and patch2 must be specified"
                   ", got {} and {}").format(patch1, patch2)
            raise ValueError(msg)
        patch1 = patch1[sidx1]
        patch2 = patch2[sidx2]

        key1 = np.rec.array((parents1, peaks1, patch1, index1),
                            dtype=[('parent', np.int64), ('peakId', np.int32),
                                   ("patch", patch1.dtype), ("index", np.int32)])
        key2 = np.rec.array((parents2, peaks2, patch2, index2),
                            dtype=[('parent', np.int64), ('peakId', np.int32),
                                   ("patch", patch2.dtype), ("index", np.int32)])
        matchColumns = ("parent", "peakId", "patch")
    else:
        key1 = np.rec.array((parents1, peaks1, index1),
                            dtype=[('parent', np.int64), ('peakId', np.int32), ("index", np.int32)])
        key2 = np.rec.array((parents2, peaks2, index2),
                            dtype=[('parent', np.int64), ('peakId', np.int32), ("index", np.int32)])
        matchColumns = ("parent", "peakId")
    # Match the two keys. This performs an inner join on the structured
    # arrays `key1` and `key2`, which stores their indices
    # as columns in a structured array.
    matched = rec_join(matchColumns, key1, key2, jointype="inner")
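    # A minimal illustration of the join semantics (hypothetical values):
    #
    #     a = np.rec.fromarrays(([1, 2], [10, 20]), names="parent,index")
    #     b = np.rec.fromarrays(([2, 3], [30, 40]), names="parent,index")
    #     rec_join("parent", a, b, jointype="inner")
    #
    # yields a single row (parent=2, index1=20, index2=30); the clashing
    # non-key "index" columns are renamed with "1"/"2" postfixes.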

    # Create the full index for both catalogs.
    indices1 = matched["index1"]
    indices2 = matched["index2"]

    # Re-index the resulting catalogs.
    matches = [
        afwTable.SourceMatch(catalog1[int(i1)], catalog2[int(i2)], 0.0)
        for i1, i2 in zip(indices1, indices2)
    ]

    return matches
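
# A hedged usage sketch (not part of the original module): ``meas1`` and
# ``meas2`` stand in for measurement catalogs deblended from the same
# mergeDet catalog, e.g. the outputs of two different deblender
# configurations.
def _exampleMatchUsage(meas1, meas2):
    matches = matchCatalogsExact(meas1, meas2)
    # Each match pairs the records of the same physical source.
    return [(m.first.getId(), m.second.getId()) for m in matches]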


class MergeDetectionsConnections(PipelineTaskConnections,
                                 dimensions=("tract", "patch", "skymap"),
                                 defaultTemplates={"inputCoaddName": "deep", "outputCoaddName": "deep"}):
    schema = cT.InitInput(
        doc="Schema of the input detection catalog",
        name="{inputCoaddName}Coadd_det_schema",
        storageClass="SourceCatalog"
    )

    outputSchema = cT.InitOutput(
        doc="Schema of the merged detection catalog",
        name="{outputCoaddName}Coadd_mergeDet_schema",
        storageClass="SourceCatalog"
    )

    outputPeakSchema = cT.InitOutput(
        doc="Output schema of the Footprint peak catalog",
        name="{outputCoaddName}Coadd_peak_schema",
        storageClass="PeakCatalog"
    )

    catalogs = cT.Input(
        doc="Detection catalogs to be merged",
        name="{inputCoaddName}Coadd_det",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True
    )

    skyMap = cT.Input(
        doc="SkyMap to be used in merging",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    outputCatalog = cT.Output(
        doc="Merged detection catalog",
        name="{outputCoaddName}Coadd_mergeDet",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
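
# With the default templates (inputCoaddName="deep", outputCoaddName="deep"),
# these connections resolve to dataset types such as "deepCoadd_det" (input),
# "deepCoadd_mergeDet" (output), and "deepCoadd_mergeDet_schema" (init
# output).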


class MergeDetectionsConfig(PipelineTaskConfig, pipelineConnections=MergeDetectionsConnections):
    """Configuration parameters for the MergeDetectionsTask.
    """
    minNewPeak = Field(dtype=float, default=1,
                       doc="Minimum distance from closest peak to create a new one (in arcsec).")

    maxSamePeak = Field(dtype=float, default=0.3,
                        doc="When adding new catalogs to the merge, all peaks less than this distance "
                            "(in arcsec) to an existing peak will be flagged as detected in that catalog.")
    cullPeaks = ConfigField(dtype=CullPeaksConfig, doc="Configuration for how to cull peaks.")

    skyFilterName = Field(dtype=str, default="sky",
                          doc="Name of 'filter' used to label sky objects (e.g. flag merge_peak_sky is set)\n"
                              "(N.b. should be in MergeMeasurementsConfig.pseudoFilterList)")
    skyObjects = ConfigurableField(target=SkyObjectsTask, doc="Generate sky objects")
    priorityList = ListField(dtype=str, default=[],
                             doc="Priority-ordered list of filter bands for the merge.")
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    idGenerator = SkyMapIdGeneratorConfig.make_field()

    def setDefaults(self):
        Config.setDefaults(self)
        self.skyObjects.avoidMask = ["DETECTED"]  # Nothing else is available in our custom mask

    def validate(self):
        super().validate()
        if len(self.priorityList) == 0:
            raise RuntimeError("No priority list provided")
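
# A hedged configuration sketch (not part of the original module); the band
# names below are purely illustrative.
def _exampleMergeConfig():
    config = MergeDetectionsConfig()
    # priorityList must be non-empty, or ``validate`` raises.
    config.priorityList = ["i", "r", "z", "y", "g", "u"]
    config.validate()
    return config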


class MergeDetectionsTask(PipelineTask):
    """Merge sources detected in coadds of exposures obtained with different filters.

    To perform photometry consistently across coadds in multiple filter
    bands, we create a master catalog of sources from all bands by merging
    the sources (peaks & footprints) detected in each coadd, while keeping
    track of which band each source originates in. The catalog merge is
    performed by
    `~lsst.afw.detection.FootprintMergeList.getMergedSourceCatalog`. Spurious
    peaks detected around bright objects are culled as described in
    `~lsst.pipe.tasks.multiBandUtils.CullPeaksConfig`.

    MergeDetectionsTask is meant to be run after detecting sources in coadds
    generated for the chosen subset of the available bands. The purpose of the
    task is to merge sources (peaks & footprints) detected in the coadds
    generated from the chosen subset of filters. Subsequent tasks in the
    multi-band processing procedure will deblend the generated master list of
    sources and, eventually, perform forced photometry.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`, optional
        The schema of the detection catalogs used as input to this task.
    initInputs : `dict`, optional
        Dictionary that can contain a key ``schema`` containing the
        input schema. If present, it will override the value of ``schema``.
    **kwargs
        Additional keyword arguments.
    """
    ConfigClass = MergeDetectionsConfig
    _DefaultName = "mergeCoaddDetections"

    def __init__(self, schema=None, initInputs=None, **kwargs):
        super().__init__(**kwargs)

        if initInputs is not None:
            schema = initInputs['schema'].schema

        if schema is None:
            raise ValueError("No input schema or initInputs['schema'] provided.")

        self.schema = schema

        self.makeSubtask("skyObjects")

        filterNames = list(self.config.priorityList)
        filterNames.append(self.config.skyFilterName)
        self.merged = afwDetect.FootprintMergeList(self.schema, filterNames)
        self.outputSchema = afwTable.SourceCatalog(self.schema)
        self.outputPeakSchema = afwDetect.PeakCatalog(self.merged.getPeakSchema())
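
    # A hedged construction sketch (``detSchema`` is illustrative): given the
    # schema of the per-band detection catalogs, the task can be built as
    #
    #     task = MergeDetectionsTask(schema=detSchema)
    #
    # or, within a pipeline, via ``initInputs={"schema": catalogWithSchema}``,
    # where the catalog's ``.schema`` attribute supplies the schema.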

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        idGenerator = self.config.idGenerator.apply(butlerQC.quantum.dataId)
        inputs["skySeed"] = idGenerator.catalog_id
        inputs["idFactory"] = idGenerator.make_table_id_factory()
        catalogDict = {ref.dataId['band']: cat for ref, cat in zip(inputRefs.catalogs,
                                                                   inputs['catalogs'])}
        inputs['catalogs'] = catalogDict
        skyMap = inputs.pop('skyMap')
        # Can use the first dataId to find the tract and patch being worked on.
        tractNumber = inputRefs.catalogs[0].dataId['tract']
        tractInfo = skyMap[tractNumber]
        patchInfo = tractInfo.getPatchInfo(inputRefs.catalogs[0].dataId['patch'])
        skyInfo = Struct(
            skyMap=skyMap,
            tractInfo=tractInfo,
            patchInfo=patchInfo,
            wcs=tractInfo.getWcs(),
            bbox=patchInfo.getOuterBBox()
        )
        inputs['skyInfo'] = skyInfo

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catalogs, skyInfo, idFactory, skySeed):
        """Merge multiple catalogs.

        After ordering the catalogs and filters in priority order,
        ``getMergedSourceCatalog`` of the
        `~lsst.afw.detection.FootprintMergeList` created by ``__init__`` is
        used to perform the actual merging. Finally, `cullPeaks` is used to
        remove garbage peaks detected around bright objects.

        Parameters
        ----------
        catalogs : `dict` [`str`, `lsst.afw.table.SourceCatalog`]
            Detection catalogs to be merged, keyed by band.
        skyInfo : `lsst.pipe.base.Struct`
            A description of the patch, including its WCS and outer bounding
            box.
        idFactory : `lsst.afw.table.IdFactory`
            Factory used to generate IDs for the merged sources.
        skySeed : `int`
            Seed for the random number generator used to place sky objects.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``outputCatalog``
                Merged catalog (`lsst.afw.table.SourceCatalog`).
        """

        # Convert the configured angular distances to tract pixel coordinates.
        tractWcs = skyInfo.wcs
        peakDistance = self.config.minNewPeak / tractWcs.getPixelScale().asArcseconds()
        samePeakDistance = self.config.maxSamePeak / tractWcs.getPixelScale().asArcseconds()
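        # For example (illustrative numbers only): at a pixel scale of
        # 0.2 arcsec/pixel, the default minNewPeak of 1.0 arcsec becomes a
        # new-peak radius of 5.0 pixels.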

        # Put catalogs, filters in priority order.
        orderedCatalogs = [catalogs[band] for band in self.config.priorityList if band in catalogs]
        orderedBands = [band for band in self.config.priorityList if band in catalogs]

        mergedList = self.merged.getMergedSourceCatalog(orderedCatalogs, orderedBands, peakDistance,
                                                        self.schema, idFactory,
                                                        samePeakDistance)

        # Add extra sources that correspond to blank sky.
        skySourceFootprints = self.getSkySourceFootprints(mergedList, skyInfo, skySeed)
        if skySourceFootprints:
            key = mergedList.schema.find("merge_footprint_%s" % self.config.skyFilterName).key
            for foot in skySourceFootprints:
                s = mergedList.addNew()
                s.setFootprint(foot)
                s.set(key, True)

        # Sort peaks from brightest to faintest.
        for record in mergedList:
            record.getFootprint().sortPeaks()
        self.log.info("Merged to %d sources", len(mergedList))
        # Attempt to remove garbage peaks.
        self.cullPeaks(mergedList)
        return Struct(outputCatalog=mergedList)

    def cullPeaks(self, catalog):
        """Attempt to remove garbage peaks (mostly on the outskirts of large blends).

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Source catalog.
        """
        keys = [item.key for item in self.merged.getPeakSchema().extract("merge_peak_*").values()]
        assert len(keys) > 0, "Error finding flags that associate peaks with their detection bands."
        totalPeaks = 0
        culledPeaks = 0
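        # Keep rule, restated: a peak survives if it ranks among the
        # brightest ``rankSufficient`` peaks of its family, or was detected
        # in at least ``nBandsSufficient`` bands, or is both within the top
        # ``rankConsidered`` peaks and the top ``rankNormalizedConsidered``
        # fraction of the family; all other peaks are culled.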

        for parentSource in catalog:
            # Make a list copy so we can clear the attached PeakCatalog and
            # append the peaks we are keeping to it (which is easier than
            # deleting as we iterate).
            keptPeaks = parentSource.getFootprint().getPeaks()
            oldPeaks = list(keptPeaks)
            keptPeaks.clear()
            familySize = len(oldPeaks)
            totalPeaks += familySize
            for rank, peak in enumerate(oldPeaks):
                if ((rank < self.config.cullPeaks.rankSufficient)
                        or (sum(peak.get(k) for k in keys) >= self.config.cullPeaks.nBandsSufficient)
                        or (rank < self.config.cullPeaks.rankConsidered
                            and rank < self.config.cullPeaks.rankNormalizedConsidered * familySize)):
                    keptPeaks.append(peak)
                else:
                    culledPeaks += 1
        self.log.info("Culled %d of %d peaks", culledPeaks, totalPeaks)

    def getSkySourceFootprints(self, mergedList, skyInfo, seed):
        """Return a list of Footprints of sky objects which don't overlap with anything in mergedList.

        Parameters
        ----------
        mergedList : `lsst.afw.table.SourceCatalog`
            The merged Footprints from all the input bands.
        skyInfo : `lsst.pipe.base.Struct`
            A description of the patch.
        seed : `int`
            Seed for the random number generator.

        Returns
        -------
        footprints : `list` of `lsst.afw.detection.Footprint`
            Footprints of the sky objects, each with a single peak and the
            ``merge_peak_<skyFilterName>`` flag set.
        """
        mask = afwImage.Mask(skyInfo.patchInfo.getOuterBBox())
        detected = mask.getPlaneBitMask("DETECTED")
        for s in mergedList:
            s.getFootprint().spans.setMask(mask, detected)

        footprints = self.skyObjects.run(mask, seed)
        if not footprints:
            return footprints

        # Convert the footprints to the peak schema that carries the
        # "merge_peak_<skyFilterName>" flags, and set that flag on each peak.
        schema = self.merged.getPeakSchema()
        mergeKey = schema.find("merge_peak_%s" % self.config.skyFilterName).key
        converted = []
        for oldFoot in footprints:
            assert len(oldFoot.getPeaks()) == 1, "Should be a single peak only"
            peak = oldFoot.getPeaks()[0]
            newFoot = afwDetect.Footprint(oldFoot.spans, schema)
            newFoot.addPeak(peak.getFx(), peak.getFy(), peak.getPeakValue())
            newFoot.getPeaks()[0].set(mergeKey, True)
            converted.append(newFoot)

        return converted
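
# A hedged end-to-end sketch (not part of the original module): ``detSchema``
# is the schema of the per-band detection catalogs, ``catalogsByBand`` a dict
# such as {"g": catG, "r": catR, "i": catI}, and ``skyInfo``, ``idFactory``
# and ``seed`` stand in for the values ``runQuantum`` would normally supply.
def _exampleRunMerge(detSchema, catalogsByBand, skyInfo, idFactory, seed):
    config = MergeDetectionsConfig()
    config.priorityList = list(catalogsByBand.keys())
    task = MergeDetectionsTask(schema=detSchema, config=config)
    result = task.run(catalogs=catalogsByBand, skyInfo=skyInfo,
                      idFactory=idFactory, skySeed=seed)
    return result.outputCatalog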