Coverage for python/lsst/pipe/tasks/mergeDetections.py: 23%

153 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-09-15 03:34 -0700

1#!/usr/bin/env python 

2# 

3# LSST Data Management System 

4# Copyright 2008-2015 AURA/LSST. 

5# 

6# This product includes software developed by the 

7# LSST Project (http://www.lsst.org/). 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the LSST License Statement and 

20# the GNU General Public License along with this program. If not, 

21# see <https://www.lsstcorp.org/LegalNotices/>. 

22# 

23 

24import numpy as np 

25from numpy.lib.recfunctions import rec_join 

26import warnings 

27 

28from .multiBandUtils import CullPeaksConfig 

29 

30import lsst.afw.detection as afwDetect 

31import lsst.afw.image as afwImage 

32import lsst.afw.table as afwTable 

33 

34from lsst.meas.algorithms import SkyObjectsTask 

35from lsst.skymap import BaseSkyMap 

36from lsst.pex.config import Config, Field, ListField, ConfigurableField, ConfigField 

37from lsst.pipe.base import (PipelineTask, PipelineTaskConfig, Struct, 

38 PipelineTaskConnections) 

39import lsst.pipe.base.connectionTypes as cT 

40from lsst.obs.base import ExposureIdInfo 

41 

42 

43def matchCatalogsExact(catalog1, catalog2, patch1=None, patch2=None): 

44 """Match two catalogs derived from the same mergeDet catalog 

45 

46 When testing downstream features, like deblending methods/parameters 

47 and measurement algorithms/parameters, it is useful to to compare 

48 the same sources in two catalogs. In most cases this must be done 

49 by matching on either RA/DEC or XY positions, which occassionally 

50 will mismatch one source with another. 

51 

52 For a more robust solution, as long as the downstream catalog is 

53 derived from the same mergeDet catalog, exact source matching 

54 can be done via the unique ``(parent, deblend_peakID)`` 

55 combination. So this function performs this exact matching for 

56 all sources both catalogs. 

57 

58 Parameters 

59 ---------- 

60 catalog1, catalog2 : `lsst.afw.table.SourceCatalog` 

61 The two catalogs to merge 

62 

63 patch1, patch2 : array of int 

64 Patch for each row, converted into an integer. 

65 

66 Returns 

67 ------- 

68 result: list of `lsst.afw.table.SourceMatch` 

69 List of matches for each source (using an inner join). 

70 """ 

71 # Only match the individual sources, the parents will 

72 # already be matched by the mergeDet catalog 

73 sidx1 = catalog1["parent"] != 0 

74 sidx2 = catalog2["parent"] != 0 

75 

76 # Create the keys used to merge the catalogs 

77 parents1 = np.array(catalog1["parent"][sidx1]) 

78 peaks1 = np.array(catalog1["deblend_peakId"][sidx1]) 

79 index1 = np.arange(len(catalog1))[sidx1] 

80 parents2 = np.array(catalog2["parent"][sidx2]) 

81 peaks2 = np.array(catalog2["deblend_peakId"][sidx2]) 

82 index2 = np.arange(len(catalog2))[sidx2] 

83 

84 if patch1 is not None: 

85 if patch2 is None: 

86 msg = ("If the catalogs are from different patches then patch1 and patch2 must be specified" 

87 ", got {} and {}").format(patch1, patch2) 

88 raise ValueError(msg) 

89 patch1 = patch1[sidx1] 

90 patch2 = patch2[sidx2] 

91 

92 key1 = np.rec.array((parents1, peaks1, patch1, index1), 

93 dtype=[('parent', np.int64), ('peakId', np.int32), 

94 ("patch", patch1.dtype), ("index", np.int32)]) 

95 key2 = np.rec.array((parents2, peaks2, patch2, index2), 

96 dtype=[('parent', np.int64), ('peakId', np.int32), 

97 ("patch", patch2.dtype), ("index", np.int32)]) 

98 matchColumns = ("parent", "peakId", "patch") 

99 else: 

100 key1 = np.rec.array((parents1, peaks1, index1), 

101 dtype=[('parent', np.int64), ('peakId', np.int32), ("index", np.int32)]) 

102 key2 = np.rec.array((parents2, peaks2, index2), 

103 dtype=[('parent', np.int64), ('peakId', np.int32), ("index", np.int32)]) 

104 matchColumns = ("parent", "peakId") 

105 # Match the two keys. 

106 # This line performs an inner join on the structured 

107 # arrays `key1` and `key2`, which stores their indices 

108 # as columns in a structured array. 

109 matched = rec_join(matchColumns, key1, key2, jointype="inner") 

110 

111 # Create the full index for both catalogs 

112 indices1 = matched["index1"] 

113 indices2 = matched["index2"] 

114 

115 # Re-index the resulting catalogs 

116 matches = [ 

117 afwTable.SourceMatch(catalog1[int(i1)], catalog2[int(i2)], 0.0) 

118 for i1, i2 in zip(indices1, indices2) 

119 ] 

120 

121 return matches 

122 

123 

class MergeDetectionsConnections(PipelineTaskConnections,
                                 dimensions=("tract", "patch", "skymap"),
                                 defaultTemplates={"inputCoaddName": 'deep', "outputCoaddName": "deep"}):
    """Butler connections for `MergeDetectionsTask`.

    Defines the datasets read and written by the task: per-band detection
    catalogs and the skymap come in; a single merged detection catalog per
    (tract, patch) goes out, along with the init-time output schemas.
    """

    # Schema of the per-band input detection catalogs, available at task
    # construction time (used to build the merged output schema).
    schema = cT.InitInput(
        doc="Schema of the input detection catalog",
        name="{inputCoaddName}Coadd_det_schema",
        storageClass="SourceCatalog"
    )

    # Schema of the merged detection catalog this task produces.
    outputSchema = cT.InitOutput(
        doc="Schema of the merged detection catalog",
        name="{outputCoaddName}Coadd_mergeDet_schema",
        storageClass="SourceCatalog"
    )

    # Schema of the peak records attached to the merged Footprints.
    outputPeakSchema = cT.InitOutput(
        doc="Output schema of the Footprint peak catalog",
        name="{outputCoaddName}Coadd_peak_schema",
        storageClass="PeakCatalog"
    )

    # One detection catalog per band for this (tract, patch); multiple=True
    # because the quantum spans all available bands.
    catalogs = cT.Input(
        doc="Detection Catalogs to be merged",
        name="{inputCoaddName}Coadd_det",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True
    )

    # Skymap providing the tract/patch geometry (WCS, bounding boxes).
    skyMap = cT.Input(
        doc="SkyMap to be used in merging",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    # The merged detection catalog, one per (tract, patch).
    outputCatalog = cT.Output(
        doc="Merged Detection catalog",
        name="{outputCoaddName}Coadd_mergeDet",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )

166 

167 

class MergeDetectionsConfig(PipelineTaskConfig, pipelineConnections=MergeDetectionsConnections):
    """Configuration parameters for the MergeDetectionsTask."""

    minNewPeak = Field(dtype=float, default=1,
                       doc="Minimum distance from closest peak to create a new one (in arcsec).")

    maxSamePeak = Field(dtype=float, default=0.3,
                        doc="When adding new catalogs to the merge, all peaks less than this distance "
                        " (in arcsec) to an existing peak will be flagged as detected in that catalog.")
    cullPeaks = ConfigField(dtype=CullPeaksConfig, doc="Configuration for how to cull peaks.")

    skyFilterName = Field(dtype=str, default="sky",
                          doc="Name of `filter' used to label sky objects (e.g. flag merge_peak_sky is set)\n"
                              "(N.b. should be in MergeMeasurementsConfig.pseudoFilterList)")
    skyObjects = ConfigurableField(target=SkyObjectsTask, doc="Generate sky objects")
    priorityList = ListField(dtype=str, default=[],
                             doc="Priority-ordered list of filter bands for the merge.")
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")

    def setDefaults(self):
        # Use cooperative super() (consistent with validate below) rather than
        # calling Config.setDefaults directly, which skipped any intermediate
        # classes in the MRO.
        super().setDefaults()
        self.skyObjects.avoidMask = ["DETECTED"]  # Nothing else is available in our custom mask

    def validate(self):
        """Raise if the configuration is unusable.

        Raises
        ------
        RuntimeError
            If ``priorityList`` is empty; the merge order is undefined
            without at least one band.
        """
        super().validate()
        if len(self.priorityList) == 0:
            raise RuntimeError("No priority list provided")

198 

199 

class MergeDetectionsTask(PipelineTask):
    """Task to merge coadd detections from multiple bands.

    Parameters
    ----------
    butler : `None`
        Compatibility parameter. Should always be `None`.
    schema : `lsst.afw.table.Schema`, optional
        The schema of the detection catalogs used as input to this task.
    initInputs : `dict`, optional
        Dictionary that can contain a key ``schema`` containing the
        input schema. If present will override the value of ``schema``.
    """
    ConfigClass = MergeDetectionsConfig
    _DefaultName = "mergeCoaddDetections"

    def __init__(self, butler=None, schema=None, initInputs=None, **kwargs):
        super().__init__(**kwargs)

        if butler is not None:
            # Gen2 compatibility shim: accept but ignore a butler argument.
            warnings.warn("The 'butler' parameter is no longer used and can be safely removed.",
                          category=FutureWarning, stacklevel=2)
            butler = None

        if initInputs is not None:
            # The pipeline framework passes the input schema catalog here;
            # it takes precedence over an explicitly supplied schema.
            schema = initInputs['schema'].schema

        if schema is None:
            raise ValueError("No input schema or initInputs['schema'] provided.")

        self.schema = schema

        self.makeSubtask("skyObjects")

        # Merge in priority order; append the sky "filter" so sky objects can
        # be flagged alongside the real bands in the merged footprints.
        filterNames = list(self.config.priorityList)
        filterNames.append(self.config.skyFilterName)
        self.merged = afwDetect.FootprintMergeList(self.schema, filterNames)
        self.outputSchema = afwTable.SourceCatalog(self.schema)
        self.outputPeakSchema = afwDetect.PeakCatalog(self.merged.getPeakSchema())

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Fetch inputs, assemble patch geometry, run the merge, and store outputs."""
        inputs = butlerQC.get(inputRefs)
        # Deterministic per-(tract, patch) seed and id factory for the merge.
        exposureIdInfo = ExposureIdInfo.fromDataId(butlerQC.quantum.dataId, "tract_patch")
        inputs["skySeed"] = exposureIdInfo.expId
        inputs["idFactory"] = exposureIdInfo.makeSourceIdFactory()
        # Re-key the list of per-band catalogs by band name for run().
        catalogDict = {ref.dataId['band']: cat for ref, cat in zip(inputRefs.catalogs,
                                                                   inputs['catalogs'])}
        inputs['catalogs'] = catalogDict
        skyMap = inputs.pop('skyMap')
        # Can use the first dataId to find the tract and patch being worked on
        tractNumber = inputRefs.catalogs[0].dataId['tract']
        tractInfo = skyMap[tractNumber]
        patchInfo = tractInfo.getPatchInfo(inputRefs.catalogs[0].dataId['patch'])
        skyInfo = Struct(
            skyMap=skyMap,
            tractInfo=tractInfo,
            patchInfo=patchInfo,
            wcs=tractInfo.getWcs(),
            bbox=patchInfo.getOuterBBox()
        )
        inputs['skyInfo'] = skyInfo

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catalogs, skyInfo, idFactory, skySeed):
        """Merge multiple catalogs.

        After ordering the catalogs and filters in priority order,
        ``getMergedSourceCatalog`` of the ``FootprintMergeList`` created by
        ``__init__`` is used to perform the actual merging. Finally,
        `cullPeaks` is used to remove garbage peaks detected around bright
        objects.

        Parameters
        ----------
        catalogs : `dict`
            Mapping from band name to `lsst.afw.table.SourceCatalog` of
            detections in that band.
        skyInfo : `lsst.pipe.base.Struct`
            Patch geometry; this method uses its ``wcs`` and ``patchInfo``
            attributes.
        idFactory : `lsst.afw.table.IdFactory`
            Factory for source ids in the merged catalog.
        skySeed : `int`
            Seed for the sky-object random number generator.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with ``outputCatalog``, the merged detection catalog.
        """

        # Convert the configured angular distances (arcsec) to pixels in the
        # tract coordinate system.
        tractWcs = skyInfo.wcs
        peakDistance = self.config.minNewPeak / tractWcs.getPixelScale().asArcseconds()
        samePeakDistance = self.config.maxSamePeak / tractWcs.getPixelScale().asArcseconds()

        # Put catalogs, filters in priority order
        orderedCatalogs = [catalogs[band] for band in self.config.priorityList if band in catalogs.keys()]
        orderedBands = [band for band in self.config.priorityList if band in catalogs.keys()]

        mergedList = self.merged.getMergedSourceCatalog(orderedCatalogs, orderedBands, peakDistance,
                                                        self.schema, idFactory,
                                                        samePeakDistance)

        #
        # Add extra sources that correspond to blank sky
        #
        skySourceFootprints = self.getSkySourceFootprints(mergedList, skyInfo, skySeed)
        if skySourceFootprints:
            # Flag each sky source as "detected" in the sky pseudo-filter.
            key = mergedList.schema.find("merge_footprint_%s" % self.config.skyFilterName).key
            for foot in skySourceFootprints:
                s = mergedList.addNew()
                s.setFootprint(foot)
                s.set(key, True)

        # Sort Peaks from brightest to faintest
        for record in mergedList:
            record.getFootprint().sortPeaks()
        self.log.info("Merged to %d sources", len(mergedList))
        # Attempt to remove garbage peaks
        self.cullPeaks(mergedList)
        return Struct(outputCatalog=mergedList)

    def cullPeaks(self, catalog):
        """Attempt to remove garbage peaks (mostly on the outskirts of large blends).

        Operates in place on the Footprints attached to ``catalog``. A peak is
        kept if it is bright enough (high rank), was detected in enough bands,
        or is within the configured considered fraction of its family.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Source catalog whose Footprint peaks are culled in place.
        """
        # One flag key per band ("merge_peak_<band>"), used to count in how
        # many bands each peak was detected.
        keys = [item.key for item in self.merged.getPeakSchema().extract("merge_peak_*").values()]
        assert len(keys) > 0, "Error finding flags that associate peaks with their detection bands."
        totalPeaks = 0
        culledPeaks = 0
        for parentSource in catalog:
            # Make a list copy so we can clear the attached PeakCatalog and append the ones we're keeping
            # to it (which is easier than deleting as we iterate).
            keptPeaks = parentSource.getFootprint().getPeaks()
            oldPeaks = list(keptPeaks)
            keptPeaks.clear()
            familySize = len(oldPeaks)
            totalPeaks += familySize
            # Peaks are assumed sorted brightest-first (run() calls sortPeaks),
            # so rank 0 is the brightest peak of the family.
            for rank, peak in enumerate(oldPeaks):
                if ((rank < self.config.cullPeaks.rankSufficient)
                        or (sum([peak.get(k) for k in keys]) >= self.config.cullPeaks.nBandsSufficient)
                        or (rank < self.config.cullPeaks.rankConsidered
                            and rank < self.config.cullPeaks.rankNormalizedConsidered * familySize)):
                    keptPeaks.append(peak)
                else:
                    culledPeaks += 1
        self.log.info("Culled %d of %d peaks", culledPeaks, totalPeaks)

    def getSchemaCatalogs(self):
        """Return a dict of empty catalogs for each catalog dataset produced by this task.

        Returns
        -------
        result : `dict`
            Dictionary mapping dataset name to an empty catalog with the
            appropriate schema (source and peak catalogs).
        """
        mergeDet = afwTable.SourceCatalog(self.schema)
        peak = afwDetect.PeakCatalog(self.merged.getPeakSchema())
        return {self.config.coaddName + "Coadd_mergeDet": mergeDet,
                self.config.coaddName + "Coadd_peak": peak}

    def getSkySourceFootprints(self, mergedList, skyInfo, seed):
        """Return a list of Footprints of sky objects which don't overlap with anything in mergedList.

        Parameters
        ----------
        mergedList : `lsst.afw.table.SourceCatalog`
            The merged Footprints from all the input bands.
        skyInfo : `lsst.pipe.base.Struct`
            A description of the patch.
        seed : `int`
            Seed for the random number generator.

        Returns
        -------
        converted : `list` of `lsst.afw.detection.Footprint`
            Sky-object footprints with single peaks carrying the
            ``merge_peak_<skyFilterName>`` flag set.
        """
        # Paint every merged detection into a mask so the sky-object subtask
        # can avoid placing objects on top of real sources.
        mask = afwImage.Mask(skyInfo.patchInfo.getOuterBBox())
        detected = mask.getPlaneBitMask("DETECTED")
        for s in mergedList:
            s.getFootprint().spans.setMask(mask, detected)

        footprints = self.skyObjects.run(mask, seed)
        if not footprints:
            return footprints

        # Need to convert the peak catalog's schema so we can set the "merge_peak_<skyFilterName>" flags
        schema = self.merged.getPeakSchema()
        mergeKey = schema.find("merge_peak_%s" % self.config.skyFilterName).key
        converted = []
        for oldFoot in footprints:
            assert len(oldFoot.getPeaks()) == 1, "Should be a single peak only"
            peak = oldFoot.getPeaks()[0]
            # Rebuild the footprint on the merged peak schema, preserving the
            # span set and the single peak's position/value.
            newFoot = afwDetect.Footprint(oldFoot.spans, schema)
            newFoot.addPeak(peak.getFx(), peak.getFy(), peak.getPeakValue())
            newFoot.getPeaks()[0].set(mergeKey, True)
            converted.append(newFoot)

        return converted