Coverage for python/lsst/pipe/tasks/mergeDetections.py: 21%

150 statements  

coverage.py v6.5.0, created at 2023-02-09 12:16 +0000

# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["MergeDetectionsConfig", "MergeDetectionsTask"]

import numpy as np
from numpy.lib.recfunctions import rec_join
import warnings

from .multiBandUtils import CullPeaksConfig

import lsst.afw.detection as afwDetect
import lsst.afw.image as afwImage
import lsst.afw.table as afwTable

from lsst.meas.algorithms import SkyObjectsTask
from lsst.skymap import BaseSkyMap
from lsst.pex.config import Config, Field, ListField, ConfigurableField, ConfigField
from lsst.pipe.base import (PipelineTask, PipelineTaskConfig, Struct,
                            PipelineTaskConnections)
import lsst.pipe.base.connectionTypes as cT
from lsst.obs.base import ExposureIdInfo


def matchCatalogsExact(catalog1, catalog2, patch1=None, patch2=None):
    """Match two catalogs derived from the same mergeDet catalog.

    When testing downstream features, like deblending methods/parameters
    and measurement algorithms/parameters, it is useful to compare
    the same sources in two catalogs. In most cases this must be done
    by matching on either RA/Dec or XY positions, which occasionally
    will mismatch one source with another.

    For a more robust solution, as long as the downstream catalog is
    derived from the same mergeDet catalog, exact source matching
    can be done via the unique ``(parent, deblend_peakId)``
    combination. So this function performs this exact matching for
    all sources in both catalogs.

    Parameters
    ----------
    catalog1, catalog2 : `lsst.afw.table.SourceCatalog`
        The two catalogs to merge.
    patch1, patch2 : `array` of `int`
        Patch for each row, converted into an integer.

    Returns
    -------
    result : `list` of `lsst.afw.table.SourceMatch`
        List of matches for each source (using an inner join).
    """
    # Only match the individual sources, the parents will
    # already be matched by the mergeDet catalog
    sidx1 = catalog1["parent"] != 0
    sidx2 = catalog2["parent"] != 0

    # Create the keys used to merge the catalogs
    parents1 = np.array(catalog1["parent"][sidx1])
    peaks1 = np.array(catalog1["deblend_peakId"][sidx1])
    index1 = np.arange(len(catalog1))[sidx1]
    parents2 = np.array(catalog2["parent"][sidx2])
    peaks2 = np.array(catalog2["deblend_peakId"][sidx2])
    index2 = np.arange(len(catalog2))[sidx2]

    if patch1 is not None:
        if patch2 is None:
            msg = ("If the catalogs are from different patches then patch1 and patch2 must be specified"
                   ", got {} and {}").format(patch1, patch2)
            raise ValueError(msg)
        patch1 = patch1[sidx1]
        patch2 = patch2[sidx2]

        key1 = np.rec.array((parents1, peaks1, patch1, index1),
                            dtype=[('parent', np.int64), ('peakId', np.int32),
                                   ("patch", patch1.dtype), ("index", np.int32)])
        key2 = np.rec.array((parents2, peaks2, patch2, index2),
                            dtype=[('parent', np.int64), ('peakId', np.int32),
                                   ("patch", patch2.dtype), ("index", np.int32)])
        matchColumns = ("parent", "peakId", "patch")
    else:
        key1 = np.rec.array((parents1, peaks1, index1),
                            dtype=[('parent', np.int64), ('peakId', np.int32), ("index", np.int32)])
        key2 = np.rec.array((parents2, peaks2, index2),
                            dtype=[('parent', np.int64), ('peakId', np.int32), ("index", np.int32)])
        matchColumns = ("parent", "peakId")
    # Match the two keys.
    # This line performs an inner join on the structured
    # arrays `key1` and `key2`, which stores their indices
    # as columns in a structured array.
    matched = rec_join(matchColumns, key1, key2, jointype="inner")
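    # For example (hypothetical values): a key1 row (parent=5, peakId=2,
    # index=10) and a key2 row (parent=5, peakId=2, index=37) agree on the
    # match columns, so the join yields a single row whose "index" columns
    # are renamed to index1=10 and index2=37 by rec_join's default postfixes.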

    # Create the full index for both catalogs
    indices1 = matched["index1"]
    indices2 = matched["index2"]

    # Re-index the resulting catalogs
    matches = [
        afwTable.SourceMatch(catalog1[int(i1)], catalog2[int(i2)], 0.0)
        for i1, i2 in zip(indices1, indices2)
    ]

    return matches
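
# A minimal usage sketch for matchCatalogsExact (illustrative only; catalog1
# and catalog2 stand for two hypothetical catalogs deblended from the same
# mergeDet catalog):
#
#     matches = matchCatalogsExact(catalog1, catalog2)
#     for match in matches:
#         # match.first and match.second are the paired records; the match
#         # distance is 0.0 because the join is exact rather than positional.
#         assert match.first["deblend_peakId"] == match.second["deblend_peakId"]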


class MergeDetectionsConnections(PipelineTaskConnections,
                                 dimensions=("tract", "patch", "skymap"),
                                 defaultTemplates={"inputCoaddName": 'deep', "outputCoaddName": "deep"}):
    schema = cT.InitInput(
        doc="Schema of the input detection catalog",
        name="{inputCoaddName}Coadd_det_schema",
        storageClass="SourceCatalog"
    )

    outputSchema = cT.InitOutput(
        doc="Schema of the merged detection catalog",
        name="{outputCoaddName}Coadd_mergeDet_schema",
        storageClass="SourceCatalog"
    )

    outputPeakSchema = cT.InitOutput(
        doc="Output schema of the Footprint peak catalog",
        name="{outputCoaddName}Coadd_peak_schema",
        storageClass="PeakCatalog"
    )

    catalogs = cT.Input(
        doc="Detection Catalogs to be merged",
        name="{inputCoaddName}Coadd_det",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True
    )

    skyMap = cT.Input(
        doc="SkyMap to be used in merging",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    outputCatalog = cT.Output(
        doc="Merged Detection catalog",
        name="{outputCoaddName}Coadd_mergeDet",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
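
# With the default templates above ({"inputCoaddName": "deep",
# "outputCoaddName": "deep"}), these connections resolve to dataset types such
# as "deepCoadd_det" (per-band input) and "deepCoadd_mergeDet" (merged output).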


class MergeDetectionsConfig(PipelineTaskConfig, pipelineConnections=MergeDetectionsConnections):
    """Configuration parameters for the MergeDetectionsTask.
    """
    minNewPeak = Field(dtype=float, default=1,
                       doc="Minimum distance from closest peak to create a new one (in arcsec).")

    maxSamePeak = Field(dtype=float, default=0.3,
                        doc="When adding new catalogs to the merge, all peaks less than this distance "
                            "(in arcsec) to an existing peak will be flagged as detected in that catalog.")
    cullPeaks = ConfigField(dtype=CullPeaksConfig, doc="Configuration for how to cull peaks.")

    skyFilterName = Field(dtype=str, default="sky",
                          doc="Name of 'filter' used to label sky objects (e.g. flag merge_peak_sky is set)\n"
                              "(N.b. should be in MergeMeasurementsConfig.pseudoFilterList)")
    skyObjects = ConfigurableField(target=SkyObjectsTask, doc="Generate sky objects")
    priorityList = ListField(dtype=str, default=[],
                             doc="Priority-ordered list of filter bands for the merge.")
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")

    def setDefaults(self):
        Config.setDefaults(self)
        self.skyObjects.avoidMask = ["DETECTED"]  # Nothing else is available in our custom mask

    def validate(self):
        super().validate()
        if len(self.priorityList) == 0:
            raise RuntimeError("No priority list provided")
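
# A minimal configuration sketch (band names are hypothetical): validate()
# raises unless a priority list is supplied, so a pipeline would set, e.g.,
#
#     config = MergeDetectionsConfig()
#     config.priorityList = ["i", "r", "z"]  # highest-priority band first
#     config.validate()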


class MergeDetectionsTask(PipelineTask):
    """Merge sources detected in coadds of exposures obtained with different filters.

    Merge sources detected in coadds of exposures obtained with different
    filters. To perform photometry consistently across coadds in multiple
    filter bands, we create a master catalog of sources from all bands by
    merging the sources (peaks & footprints) detected in each coadd, while
    keeping track of which band each source originates in. The catalog merge
    is performed by
    `~lsst.afw.detection.FootprintMergeList.getMergedSourceCatalog`. Spurious
    peaks detected around bright objects are culled as described in
    `~lsst.pipe.tasks.multiBandUtils.CullPeaksConfig`.

    MergeDetectionsTask is meant to be run after detecting sources in coadds
    generated for the chosen subset of the available bands. The purpose of the
    task is to merge sources (peaks & footprints) detected in the coadds
    generated from the chosen subset of filters. Subsequent tasks in the
    multi-band processing procedure will deblend the generated master list of
    sources and, eventually, perform forced photometry.

    Parameters
    ----------
    butler : `None`, optional
        Compatibility parameter. Should always be `None`.
    schema : `lsst.afw.table.Schema`, optional
        The schema of the detection catalogs used as input to this task.
    initInputs : `dict`, optional
        Dictionary that can contain a key ``schema`` containing the
        input schema. If present will override the value of ``schema``.
    **kwargs
        Additional keyword arguments.
    """
    ConfigClass = MergeDetectionsConfig
    _DefaultName = "mergeCoaddDetections"

    def __init__(self, butler=None, schema=None, initInputs=None, **kwargs):
        super().__init__(**kwargs)

        if butler is not None:
            warnings.warn("The 'butler' parameter is no longer used and can be safely removed.",
                          category=FutureWarning, stacklevel=2)
            butler = None

        if initInputs is not None:
            schema = initInputs['schema'].schema

        if schema is None:
            raise ValueError("No input schema or initInputs['schema'] provided.")

        self.schema = schema

        self.makeSubtask("skyObjects")

        filterNames = list(self.config.priorityList)
        filterNames.append(self.config.skyFilterName)
        self.merged = afwDetect.FootprintMergeList(self.schema, filterNames)
        self.outputSchema = afwTable.SourceCatalog(self.schema)
        self.outputPeakSchema = afwDetect.PeakCatalog(self.merged.getPeakSchema())

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        exposureIdInfo = ExposureIdInfo.fromDataId(butlerQC.quantum.dataId, "tract_patch")
        inputs["skySeed"] = exposureIdInfo.expId
        inputs["idFactory"] = exposureIdInfo.makeSourceIdFactory()
        catalogDict = {ref.dataId['band']: cat for ref, cat in zip(inputRefs.catalogs,
                                                                   inputs['catalogs'])}
        inputs['catalogs'] = catalogDict
        skyMap = inputs.pop('skyMap')
        # Can use the first dataId to find the tract and patch being worked on
        tractNumber = inputRefs.catalogs[0].dataId['tract']
        tractInfo = skyMap[tractNumber]
        patchInfo = tractInfo.getPatchInfo(inputRefs.catalogs[0].dataId['patch'])
        skyInfo = Struct(
            skyMap=skyMap,
            tractInfo=tractInfo,
            patchInfo=patchInfo,
            wcs=tractInfo.getWcs(),
            bbox=patchInfo.getOuterBBox()
        )
        inputs['skyInfo'] = skyInfo

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catalogs, skyInfo, idFactory, skySeed):
        """Merge multiple catalogs.

        After ordering the catalogs and filters in priority order,
        ``getMergedSourceCatalog`` of the
        `~lsst.afw.detection.FootprintMergeList` created by ``__init__`` is
        used to perform the actual merging. Finally, `cullPeaks` is used to
        remove garbage peaks detected around bright objects.

        Parameters
        ----------
        catalogs : `dict` [`str`, `lsst.afw.table.SourceCatalog`]
            Detection catalogs to be merged, keyed by band.
        skyInfo : `lsst.pipe.base.Struct`
            A description of the patch, including its WCS.
        idFactory : `lsst.afw.table.IdFactory`
            Factory used to generate IDs for the merged sources.
        skySeed : `int`
            Seed for the random number generator used to place sky objects.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``outputCatalog``
                Merged catalogs (`lsst.afw.table.SourceCatalog`).
        """
        # Convert distance to tract coordinate
        tractWcs = skyInfo.wcs
        peakDistance = self.config.minNewPeak / tractWcs.getPixelScale().asArcseconds()
        samePeakDistance = self.config.maxSamePeak / tractWcs.getPixelScale().asArcseconds()
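        # For example (hypothetical pixel scale): with a tract WCS of
        # 0.2 arcsec/pixel, the default minNewPeak of 1 arcsec gives a
        # peakDistance of 5 pixels, and maxSamePeak of 0.3 arcsec gives a
        # samePeakDistance of 1.5 pixels.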

        # Put catalogs, filters in priority order
        orderedCatalogs = [catalogs[band] for band in self.config.priorityList if band in catalogs.keys()]
        orderedBands = [band for band in self.config.priorityList if band in catalogs.keys()]

        mergedList = self.merged.getMergedSourceCatalog(orderedCatalogs, orderedBands, peakDistance,
                                                        self.schema, idFactory,
                                                        samePeakDistance)

        #
        # Add extra sources that correspond to blank sky
        #
        skySourceFootprints = self.getSkySourceFootprints(mergedList, skyInfo, skySeed)
        if skySourceFootprints:
            key = mergedList.schema.find("merge_footprint_%s" % self.config.skyFilterName).key
            for foot in skySourceFootprints:
                s = mergedList.addNew()
                s.setFootprint(foot)
                s.set(key, True)

        # Sort Peaks from brightest to faintest
        for record in mergedList:
            record.getFootprint().sortPeaks()
        self.log.info("Merged to %d sources", len(mergedList))
        # Attempt to remove garbage peaks
        self.cullPeaks(mergedList)
        return Struct(outputCatalog=mergedList)

    def cullPeaks(self, catalog):
        """Attempt to remove garbage peaks (mostly on the outskirts of large blends).

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Source catalog.
        """
        keys = [item.key for item in self.merged.getPeakSchema().extract("merge_peak_*").values()]
        assert len(keys) > 0, "Error finding flags that associate peaks with their detection bands."
        totalPeaks = 0
        culledPeaks = 0
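        # A peak is kept if any of the following holds (thresholds from
        # CullPeaksConfig): it ranks among the brightest rankSufficient peaks
        # of its family, it was detected in at least nBandsSufficient bands,
        # or its rank is below both rankConsidered and
        # rankNormalizedConsidered times the family size.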

        for parentSource in catalog:
            # Make a list copy so we can clear the attached PeakCatalog and append the ones we're keeping
            # to it (which is easier than deleting as we iterate).
            keptPeaks = parentSource.getFootprint().getPeaks()
            oldPeaks = list(keptPeaks)
            keptPeaks.clear()
            familySize = len(oldPeaks)
            totalPeaks += familySize
            for rank, peak in enumerate(oldPeaks):
                if ((rank < self.config.cullPeaks.rankSufficient)
                        or (sum([peak.get(k) for k in keys]) >= self.config.cullPeaks.nBandsSufficient)
                        or (rank < self.config.cullPeaks.rankConsidered
                            and rank < self.config.cullPeaks.rankNormalizedConsidered * familySize)):
                    keptPeaks.append(peak)
                else:
                    culledPeaks += 1
        self.log.info("Culled %d of %d peaks", culledPeaks, totalPeaks)

    def getSkySourceFootprints(self, mergedList, skyInfo, seed):
        """Return a list of Footprints of sky objects which don't overlap with anything in mergedList.

        Parameters
        ----------
        mergedList : `lsst.afw.table.SourceCatalog`
            The merged Footprints from all the input bands.
        skyInfo : `lsst.pipe.base.Struct`
            A description of the patch.
        seed : `int`
            Seed for the random number generator.

        Returns
        -------
        footprints : `list` of `lsst.afw.detection.Footprint`
            Footprints of sky objects, each with a single peak flagged as
            belonging to the sky "filter".
        """
        mask = afwImage.Mask(skyInfo.patchInfo.getOuterBBox())
        detected = mask.getPlaneBitMask("DETECTED")
        for s in mergedList:
            s.getFootprint().spans.setMask(mask, detected)
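        # The mask now has DETECTED set wherever any merged footprint lies, so
        # the sky-objects subtask (configured with avoidMask=["DETECTED"] in
        # setDefaults) will only place sky objects on blank sky.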

        footprints = self.skyObjects.run(mask, seed)
        if not footprints:
            return footprints

        # Need to convert the peak catalog's schema so we can set the "merge_peak_<skyFilterName>" flags
        schema = self.merged.getPeakSchema()
        mergeKey = schema.find("merge_peak_%s" % self.config.skyFilterName).key
        converted = []
        for oldFoot in footprints:
            assert len(oldFoot.getPeaks()) == 1, "Should be a single peak only"
            peak = oldFoot.getPeaks()[0]
            newFoot = afwDetect.Footprint(oldFoot.spans, schema)
            newFoot.addPeak(peak.getFx(), peak.getFy(), peak.getPeakValue())
            newFoot.getPeaks()[0].set(mergeKey, True)
            converted.append(newFoot)

        return converted