# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["MergeDetectionsConfig", "MergeDetectionsTask"]

import numpy as np
from numpy.lib.recfunctions import rec_join
import warnings

from .multiBandUtils import CullPeaksConfig

import lsst.afw.detection as afwDetect
import lsst.afw.image as afwImage
import lsst.afw.table as afwTable

from lsst.meas.algorithms import SkyObjectsTask
from lsst.skymap import BaseSkyMap
from lsst.pex.config import Config, Field, ListField, ConfigurableField, ConfigField
from lsst.pipe.base import (PipelineTask, PipelineTaskConfig, Struct,
                            PipelineTaskConnections)
import lsst.pipe.base.connectionTypes as cT
from lsst.meas.base import SkyMapIdGeneratorConfig


def matchCatalogsExact(catalog1, catalog2, patch1=None, patch2=None):
    """Match two catalogs derived from the same mergeDet catalog.

    When testing downstream features, like deblending methods/parameters
    and measurement algorithms/parameters, it is useful to compare
    the same sources in two catalogs. In most cases this must be done
    by matching on either RA/Dec or x/y positions, which occasionally
    will mismatch one source with another.

    For a more robust solution, as long as the downstream catalog is
    derived from the same mergeDet catalog, exact source matching
    can be done via the unique ``(parent, deblend_peakId)``
    combination. So this function performs this exact matching for
    all sources in both catalogs.

    Parameters
    ----------
    catalog1, catalog2 : `lsst.afw.table.SourceCatalog`
        The two catalogs to match.
    patch1, patch2 : `array` of `int`, optional
        Patch for each row, converted into an integer.
        Only needed when the catalogs come from different patches.

    Returns
    -------
    result : `list` of `lsst.afw.table.SourceMatch`
        List of matches for each source (using an inner join).
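
    Examples
    --------
    A minimal, illustrative sketch, assuming both catalogs were derived
    from the same mergeDet catalog::

        matches = matchCatalogsExact(catalog1, catalog2)
        for match in matches:
            src1, src2 = match.first, match.second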

69 """ 

70 # Only match the individual sources, the parents will 

71 # already be matched by the mergeDet catalog 

72 sidx1 = catalog1["parent"] != 0 

73 sidx2 = catalog2["parent"] != 0 

74 

75 # Create the keys used to merge the catalogs 

76 parents1 = np.array(catalog1["parent"][sidx1]) 

77 peaks1 = np.array(catalog1["deblend_peakId"][sidx1]) 

78 index1 = np.arange(len(catalog1))[sidx1] 

79 parents2 = np.array(catalog2["parent"][sidx2]) 

80 peaks2 = np.array(catalog2["deblend_peakId"][sidx2]) 

81 index2 = np.arange(len(catalog2))[sidx2] 

82 

83 if patch1 is not None: 

84 if patch2 is None: 

85 msg = ("If the catalogs are from different patches then patch1 and patch2 must be specified" 

86 ", got {} and {}").format(patch1, patch2) 

87 raise ValueError(msg) 

88 patch1 = patch1[sidx1] 

89 patch2 = patch2[sidx2] 

90 

91 key1 = np.rec.array((parents1, peaks1, patch1, index1), 

92 dtype=[('parent', np.int64), ('peakId', np.int32), 

93 ("patch", patch1.dtype), ("index", np.int32)]) 

94 key2 = np.rec.array((parents2, peaks2, patch2, index2), 

95 dtype=[('parent', np.int64), ('peakId', np.int32), 

96 ("patch", patch2.dtype), ("index", np.int32)]) 

97 matchColumns = ("parent", "peakId", "patch") 

98 else: 

99 key1 = np.rec.array((parents1, peaks1, index1), 

100 dtype=[('parent', np.int64), ('peakId', np.int32), ("index", np.int32)]) 

101 key2 = np.rec.array((parents2, peaks2, index2), 

102 dtype=[('parent', np.int64), ('peakId', np.int32), ("index", np.int32)]) 

103 matchColumns = ("parent", "peakId") 

104 # Match the two keys. 

105 # This line performs an inner join on the structured 

106 # arrays `key1` and `key2`, which stores their indices 

107 # as columns in a structured array. 

108 matched = rec_join(matchColumns, key1, key2, jointype="inner") 
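    # For example (illustrative values only): if key1 holds a row
    # (parent=5, peakId=2, index=7) and key2 holds a matching row
    # (parent=5, peakId=2, index=12), the joined row carries the renamed
    # columns index1=7 and index2=12, pairing catalog1[7] with catalog2[12].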

    # Create the full index for both catalogs
    indices1 = matched["index1"]
    indices2 = matched["index2"]

    # Re-index the resulting catalogs
    matches = [
        afwTable.SourceMatch(catalog1[int(i1)], catalog2[int(i2)], 0.0)
        for i1, i2 in zip(indices1, indices2)
    ]

    return matches


class MergeDetectionsConnections(PipelineTaskConnections,
                                 dimensions=("tract", "patch", "skymap"),
                                 defaultTemplates={"inputCoaddName": "deep", "outputCoaddName": "deep"}):
    schema = cT.InitInput(
        doc="Schema of the input detection catalog",
        name="{inputCoaddName}Coadd_det_schema",
        storageClass="SourceCatalog"
    )

    outputSchema = cT.InitOutput(
        doc="Schema of the merged detection catalog",
        name="{outputCoaddName}Coadd_mergeDet_schema",
        storageClass="SourceCatalog"
    )

    outputPeakSchema = cT.InitOutput(
        doc="Output schema of the Footprint peak catalog",
        name="{outputCoaddName}Coadd_peak_schema",
        storageClass="PeakCatalog"
    )

    catalogs = cT.Input(
        doc="Detection catalogs to be merged",
        name="{inputCoaddName}Coadd_det",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True
    )

    skyMap = cT.Input(
        doc="SkyMap to be used in merging",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    outputCatalog = cT.Output(
        doc="Merged detection catalog",
        name="{outputCoaddName}Coadd_mergeDet",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )


class MergeDetectionsConfig(PipelineTaskConfig, pipelineConnections=MergeDetectionsConnections):
    """Configuration parameters for the MergeDetectionsTask.

169 """ 

170 minNewPeak = Field(dtype=float, default=1, 

171 doc="Minimum distance from closest peak to create a new one (in arcsec).") 

172 

173 maxSamePeak = Field(dtype=float, default=0.3, 

174 doc="When adding new catalogs to the merge, all peaks less than this distance " 

175 " (in arcsec) to an existing peak will be flagged as detected in that catalog.") 

176 cullPeaks = ConfigField(dtype=CullPeaksConfig, doc="Configuration for how to cull peaks.") 

177 

178 skyFilterName = Field(dtype=str, default="sky", 

179 doc="Name of `filter' used to label sky objects (e.g. flag merge_peak_sky is set)\n" 

180 "(N.b. should be in MergeMeasurementsConfig.pseudoFilterList)") 

181 skyObjects = ConfigurableField(target=SkyObjectsTask, doc="Generate sky objects") 

182 priorityList = ListField(dtype=str, default=[], 

183 doc="Priority-ordered list of filter bands for the merge.") 

184 coaddName = Field(dtype=str, default="deep", doc="Name of coadd") 

185 idGenerator = SkyMapIdGeneratorConfig.make_field() 

186 

187 def setDefaults(self): 

188 Config.setDefaults(self) 

189 self.skyObjects.avoidMask = ["DETECTED"] # Nothing else is available in our custom mask 

190 

191 def validate(self): 

192 super().validate() 

193 if len(self.priorityList) == 0: 

194 raise RuntimeError("No priority list provided") 

195 

196 

class MergeDetectionsTask(PipelineTask):
    """Merge sources detected in coadds of exposures obtained with different filters.

    To perform photometry consistently across coadds in multiple filter
    bands, we create a master catalog of sources from all bands by merging
    the sources (peaks & footprints) detected in each coadd, while keeping
    track of which band each source originates in. The catalog merge is
    performed by
    `~lsst.afw.detection.FootprintMergeList.getMergedSourceCatalog`.
    Spurious peaks detected around bright objects are culled as described
    in `~lsst.pipe.tasks.multiBandUtils.CullPeaksConfig`.

    MergeDetectionsTask is meant to be run after detecting sources in
    coadds generated for the chosen subset of the available bands. The
    purpose of the task is to merge sources (peaks & footprints) detected
    in the coadds generated from the chosen subset of filters. Subsequent
    tasks in the multi-band processing procedure will deblend the
    generated master list of sources and, eventually, perform forced
    photometry.

    Parameters
    ----------
    butler : `None`, optional
        Compatibility parameter. Should always be `None`.
    schema : `lsst.afw.table.Schema`, optional
        The schema of the detection catalogs used as input to this task.
    initInputs : `dict`, optional
        Dictionary that can contain a key ``schema`` containing the
        input schema. If present will override the value of ``schema``.
    **kwargs
        Additional keyword arguments.
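
    Examples
    --------
    A minimal, illustrative construction; in production the schema comes
    from the butler via ``initInputs``::

        schema = afwTable.SourceTable.makeMinimalSchema()
        config = MergeDetectionsConfig()
        config.priorityList = ["i", "r"]
        task = MergeDetectionsTask(schema=schema, config=config)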

228 """ 

229 ConfigClass = MergeDetectionsConfig 

230 _DefaultName = "mergeCoaddDetections" 

231 

232 def __init__(self, butler=None, schema=None, initInputs=None, **kwargs): 

233 super().__init__(**kwargs) 

234 

235 if butler is not None: 

236 warnings.warn("The 'butler' parameter is no longer used and can be safely removed.", 

237 category=FutureWarning, stacklevel=2) 

238 butler = None 

239 

240 if initInputs is not None: 

241 schema = initInputs['schema'].schema 

242 

243 if schema is None: 

244 raise ValueError("No input schema or initInputs['schema'] provided.") 

245 

246 self.schema = schema 

247 

248 self.makeSubtask("skyObjects") 

249 

250 filterNames = list(self.config.priorityList) 

251 filterNames.append(self.config.skyFilterName) 

252 self.merged = afwDetect.FootprintMergeList(self.schema, filterNames) 

253 self.outputSchema = afwTable.SourceCatalog(self.schema) 

254 self.outputPeakSchema = afwDetect.PeakCatalog(self.merged.getPeakSchema()) 

255 

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        idGenerator = self.config.idGenerator.apply(butlerQC.quantum.dataId)
        inputs["skySeed"] = idGenerator.catalog_id
        inputs["idFactory"] = idGenerator.make_table_id_factory()
        catalogDict = {ref.dataId['band']: cat for ref, cat in zip(inputRefs.catalogs,
                                                                   inputs['catalogs'])}
        inputs['catalogs'] = catalogDict
        skyMap = inputs.pop('skyMap')
        # Can use the first dataId to find the tract and patch being worked on
        tractNumber = inputRefs.catalogs[0].dataId['tract']
        tractInfo = skyMap[tractNumber]
        patchInfo = tractInfo.getPatchInfo(inputRefs.catalogs[0].dataId['patch'])
        skyInfo = Struct(
            skyMap=skyMap,
            tractInfo=tractInfo,
            patchInfo=patchInfo,
            wcs=tractInfo.getWcs(),
            bbox=patchInfo.getOuterBBox()
        )
        inputs['skyInfo'] = skyInfo

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catalogs, skyInfo, idFactory, skySeed):
        """Merge multiple catalogs.

        After ordering the catalogs and filters in priority order,
        ``getMergedSourceCatalog`` of the
        `~lsst.afw.detection.FootprintMergeList` created by ``__init__`` is
        used to perform the actual merging. Finally, `cullPeaks` is used to
        remove garbage peaks detected around bright objects.

        Parameters
        ----------
        catalogs : `dict` [`str`, `lsst.afw.table.SourceCatalog`]
            Mapping from band name to the detection catalog for that band.
        skyInfo : `lsst.pipe.base.Struct`
            Description of the patch, including its WCS.
        idFactory : `lsst.afw.table.IdFactory`
            Factory used to generate unique ids for the merged sources.
        skySeed : `int`
            Seed passed to the random number generator when generating
            sky objects.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``outputCatalog``
                Merged catalog (`lsst.afw.table.SourceCatalog`).
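
        Examples
        --------
        A minimal, illustrative sketch; ``skyInfo``, ``idFactory`` and
        ``seed`` are assumed to have been prepared as in `runQuantum`::

            result = task.run(catalogs={"i": catI, "r": catR},
                              skyInfo=skyInfo, idFactory=idFactory,
                              skySeed=seed)
            merged = result.outputCatalog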

304 """ 

305 # Convert distance to tract coordinate 

306 tractWcs = skyInfo.wcs 

307 peakDistance = self.config.minNewPeak / tractWcs.getPixelScale().asArcseconds() 

308 samePeakDistance = self.config.maxSamePeak / tractWcs.getPixelScale().asArcseconds() 

309 

310 # Put catalogs, filters in priority order 

311 orderedCatalogs = [catalogs[band] for band in self.config.priorityList if band in catalogs.keys()] 

312 orderedBands = [band for band in self.config.priorityList if band in catalogs.keys()] 

313 

314 mergedList = self.merged.getMergedSourceCatalog(orderedCatalogs, orderedBands, peakDistance, 

315 self.schema, idFactory, 

316 samePeakDistance) 

317 

318 # 

319 # Add extra sources that correspond to blank sky 

320 # 

321 skySourceFootprints = self.getSkySourceFootprints(mergedList, skyInfo, skySeed) 

322 if skySourceFootprints: 

323 key = mergedList.schema.find("merge_footprint_%s" % self.config.skyFilterName).key 

324 for foot in skySourceFootprints: 

325 s = mergedList.addNew() 

326 s.setFootprint(foot) 

327 s.set(key, True) 

328 

329 # Sort Peaks from brightest to faintest 

330 for record in mergedList: 

331 record.getFootprint().sortPeaks() 

332 self.log.info("Merged to %d sources", len(mergedList)) 

333 # Attempt to remove garbage peaks 

334 self.cullPeaks(mergedList) 

335 return Struct(outputCatalog=mergedList) 

336 

    def cullPeaks(self, catalog):
        """Attempt to remove garbage peaks (mostly on the outskirts of large blends).

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Source catalog.
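
        Notes
        -----
        A peak is kept if any of the following hold, with all thresholds
        taken from ``self.config.cullPeaks``:

        - its brightness rank within its family is below
          ``rankSufficient``;
        - it was detected in at least ``nBandsSufficient`` bands;
        - its rank is below ``rankConsidered`` and below
          ``rankNormalizedConsidered`` times the family size.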

344 """ 

345 keys = [item.key for item in self.merged.getPeakSchema().extract("merge_peak_*").values()] 

346 assert len(keys) > 0, "Error finding flags that associate peaks with their detection bands." 

347 totalPeaks = 0 

348 culledPeaks = 0 

349 for parentSource in catalog: 

350 # Make a list copy so we can clear the attached PeakCatalog and append the ones we're keeping 

351 # to it (which is easier than deleting as we iterate). 

352 keptPeaks = parentSource.getFootprint().getPeaks() 

353 oldPeaks = list(keptPeaks) 

354 keptPeaks.clear() 

355 familySize = len(oldPeaks) 

356 totalPeaks += familySize 

357 for rank, peak in enumerate(oldPeaks): 

358 if ((rank < self.config.cullPeaks.rankSufficient) 

359 or (sum([peak.get(k) for k in keys]) >= self.config.cullPeaks.nBandsSufficient) 

360 or (rank < self.config.cullPeaks.rankConsidered 

361 and rank < self.config.cullPeaks.rankNormalizedConsidered * familySize)): 

362 keptPeaks.append(peak) 

363 else: 

364 culledPeaks += 1 

365 self.log.info("Culled %d of %d peaks", culledPeaks, totalPeaks) 

366 

    def getSkySourceFootprints(self, mergedList, skyInfo, seed):
        """Return a list of Footprints of sky objects which don't overlap
        with anything in ``mergedList``.

        Parameters
        ----------
        mergedList : `lsst.afw.table.SourceCatalog`
            The merged Footprints from all the input bands.
        skyInfo : `lsst.pipe.base.Struct`
            A description of the patch.
        seed : `int`
            Seed for the random number generator.
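
        Returns
        -------
        footprints : `list` of `lsst.afw.detection.Footprint`
            Sky object footprints, each with a single peak that has its
            ``merge_peak_<skyFilterName>`` flag set.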

378 """ 

379 mask = afwImage.Mask(skyInfo.patchInfo.getOuterBBox()) 

380 detected = mask.getPlaneBitMask("DETECTED") 

381 for s in mergedList: 

382 s.getFootprint().spans.setMask(mask, detected) 

383 

384 footprints = self.skyObjects.run(mask, seed) 

385 if not footprints: 

386 return footprints 

387 

388 # Need to convert the peak catalog's schema so we can set the "merge_peak_<skyFilterName>" flags 

389 schema = self.merged.getPeakSchema() 

390 mergeKey = schema.find("merge_peak_%s" % self.config.skyFilterName).key 

391 converted = [] 

392 for oldFoot in footprints: 

393 assert len(oldFoot.getPeaks()) == 1, "Should be a single peak only" 

394 peak = oldFoot.getPeaks()[0] 

395 newFoot = afwDetect.Footprint(oldFoot.spans, schema) 

396 newFoot.addPeak(peak.getFx(), peak.getFy(), peak.getPeakValue()) 

397 newFoot.getPeaks()[0].set(mergeKey, True) 

398 converted.append(newFoot) 

399 

400 return converted