Coverage for python/lsst/pipe/tasks/mergeMeasurements.py: 15%

131 statements  

coverage.py v7.5.1, created at 2024-05-12 01:56 -0700

# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["MergeMeasurementsConfig", "MergeMeasurementsTask"]

import numpy

import lsst.afw.table as afwTable
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase

from lsst.pipe.base import PipelineTaskConnections, PipelineTaskConfig
import lsst.pipe.base.connectionTypes as cT

class MergeMeasurementsConnections(PipelineTaskConnections,
                                   dimensions=("skymap", "tract", "patch"),
                                   defaultTemplates={"inputCoaddName": "deep",
                                                     "outputCoaddName": "deep"}):
    inputSchema = cT.InitInput(
        doc="Schema of the input measurement catalogs.",
        name="{inputCoaddName}Coadd_meas_schema",
        storageClass="SourceCatalog",
    )
    outputSchema = cT.InitOutput(
        doc="Schema for the output merged measurement catalog.",
        name="{outputCoaddName}Coadd_ref_schema",
        storageClass="SourceCatalog",
    )
    catalogs = cT.Input(
        doc="Input catalogs to merge.",
        name="{inputCoaddName}Coadd_meas",
        multiple=True,
        storageClass="SourceCatalog",
        dimensions=["band", "skymap", "tract", "patch"],
    )
    mergedCatalog = cT.Output(
        doc="Output merged catalog.",
        name="{outputCoaddName}Coadd_ref",
        storageClass="SourceCatalog",
        dimensions=["skymap", "tract", "patch"],
    )

class MergeMeasurementsConfig(PipelineTaskConfig, pipelineConnections=MergeMeasurementsConnections):
    """Configuration parameters for the MergeMeasurementsTask.
    """
    pseudoFilterList = pexConfig.ListField(
        dtype=str,
        default=["sky"],
        doc="Names of filters which may have no associated detection\n"
            "(N.b. should include MergeDetectionsConfig.skyFilterName)"
    )
    snName = pexConfig.Field(
        dtype=str,
        default="base_PsfFlux",
        doc="Name of the flux measurement used to compute the S/N when choosing the reference band."
    )
    minSN = pexConfig.Field(
        dtype=float,
        default=10.,
        doc="If the S/N in the priority band is below this value (and the highest S/N in any other "
            "band exceeds the priority band's S/N by more than minSNDiff), use the band with the "
            "largest S/N as the reference band."
    )
    minSNDiff = pexConfig.Field(
        dtype=float,
        default=3.,
        doc="If the difference in S/N between another band and the priority band is larger "
            "than this value (and the S/N in the priority band is less than minSN), "
            "use the band with the largest S/N as the reference band."
    )
    flags = pexConfig.ListField(
        dtype=str,
        doc="Require that these flags, if available, are not set",
        default=["base_PixelFlags_flag_interpolatedCenter", "base_PsfFlux_flag",
                 "ext_photometryKron_KronFlux_flag", "modelfit_CModel_flag", ]
    )
    priorityList = pexConfig.ListField(
        dtype=str,
        default=[],
        doc="Priority-ordered list of filter bands for the merge."
    )
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )

    def validate(self):
        super().validate()
        if len(self.priorityList) == 0:
            raise RuntimeError("No priority list provided")
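# A minimal configuration sketch (illustrative only; the band names and their
# ordering below are assumptions about the data being processed, not defaults
# of this module):
#
#     config = MergeMeasurementsConfig()
#     config.priorityList = ["i", "r", "z", "y", "g", "u"]
#     config.minSN = 10.0      # default value
#     config.minSNDiff = 3.0   # default value
#     config.validate()        # raises RuntimeError if priorityList is empty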


class MergeMeasurementsTask(pipeBase.PipelineTask):
    """Merge measurements from multiple bands.

    Combines consistent (i.e. with the same peaks and footprints) catalogs of
    sources from multiple filter bands to construct a unified catalog that is
    suitable for driving forced photometry. Every source is required to have
    centroid, shape and flux measurements in each band.

    MergeMeasurementsTask is meant to be run after deblending & measuring
    sources in every band. The purpose of the task is to generate a catalog of
    sources suitable for driving forced photometry in coadds and individual
    exposures.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`, optional
        The schema of the detection catalogs used as input to this task.
    initInputs : `dict`, optional
        Dictionary that can contain a key ``inputSchema`` containing the
        input schema. If present will override the value of ``schema``.
    **kwargs
        Additional keyword arguments.
    """

    _DefaultName = "mergeCoaddMeasurements"
    ConfigClass = MergeMeasurementsConfig

    inputDataset = "meas"
    outputDataset = "ref"

    def __init__(self, schema=None, initInputs=None, **kwargs):
        super().__init__(**kwargs)

        if initInputs is not None:
            schema = initInputs['inputSchema'].schema

        if schema is None:
            raise ValueError("No input schema or initInputs['inputSchema'] provided.")

        inputSchema = schema

        self.schemaMapper = afwTable.SchemaMapper(inputSchema, True)
        self.schemaMapper.addMinimalSchema(inputSchema, True)
        self.instFluxKey = inputSchema.find(self.config.snName + "_instFlux").getKey()
        self.instFluxErrKey = inputSchema.find(self.config.snName + "_instFluxErr").getKey()
        self.fluxFlagKey = inputSchema.find(self.config.snName + "_flag").getKey()

        self.flagKeys = {}
        for band in self.config.priorityList:
            outputKey = self.schemaMapper.editOutputSchema().addField(
                "merge_measurement_%s" % band,
                type="Flag",
                doc="Flag field set if the measurements here are from the %s filter" % band
            )
            peakKey = inputSchema.find("merge_peak_%s" % band).key
            footprintKey = inputSchema.find("merge_footprint_%s" % band).key
            self.flagKeys[band] = pipeBase.Struct(peak=peakKey, footprint=footprintKey, output=outputKey)
        self.schema = self.schemaMapper.getOutputSchema()

        self.pseudoFilterKeys = []
        for filt in self.config.pseudoFilterList:
            try:
                self.pseudoFilterKeys.append(self.schema.find("merge_peak_%s" % filt).getKey())
            except Exception as e:
                self.log.warning("merge_peak is not set for pseudo-filter %s: %s", filt, e)

        self.badFlags = {}
        for flag in self.config.flags:
            try:
                self.badFlags[flag] = self.schema.find(flag).getKey()
            except KeyError as exc:
                self.log.warning("Can't find flag %s in schema: %s", flag, exc)
        self.outputSchema = afwTable.SourceCatalog(self.schema)
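        # At this point the output schema contains one merge_measurement_<band>
        # Flag field per entry in config.priorityList (e.g. merge_measurement_i
        # for an assumed band "i"); run() sets exactly one of these flags on
        # each output record to indicate which band supplied its measurements.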

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        # Repackage the flat list of per-band input catalogs into a dict keyed
        # by band name, which is the form run() expects.
        dataIds = (ref.dataId for ref in inputRefs.catalogs)
        catalogDict = {dataId['band']: cat for dataId, cat in zip(dataIds, inputs['catalogs'])}
        inputs['catalogs'] = catalogDict
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catalogs):
        """Merge measurement catalogs to create a single reference catalog for forced photometry.

        Parameters
        ----------
        catalogs : `dict` [`str`, `lsst.afw.table.SourceCatalog`]
            Measurement catalogs to be merged, keyed by band name.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``mergedCatalog``
                The merged reference catalog (`lsst.afw.table.SourceCatalog`).

        Raises
        ------
        ValueError
            Raised if the source IDs of the input catalogs do not match;
            if no valid reference record is found for an input record;
            or if there is a mismatch between catalog sizes.

        Notes
        -----
        For parent sources, we choose the first band in config.priorityList for which the
        merge_footprint flag for that band is True.

        For child sources, the logic is the same, except that we use the merge_peak flags.
        """
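        # Illustration of the selection described in the Notes above (band
        # names are examples only): with priorityList=["i", "r"], a parent
        # source with merge_footprint_i set is taken from the i-band record,
        # while a deblended child with merge_peak_i unset but merge_peak_r set
        # is taken from the r-band record, subject to the S/N override applied
        # further below.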

        # Put catalogs, filters in priority order
        orderedCatalogs = [catalogs[band] for band in self.config.priorityList if band in catalogs.keys()]
        orderedKeys = [self.flagKeys[band] for band in self.config.priorityList if band in catalogs.keys()]

        mergedCatalog = afwTable.SourceCatalog(self.schema)
        mergedCatalog.reserve(len(orderedCatalogs[0]))

        idKey = orderedCatalogs[0].table.getIdKey()
        for catalog in orderedCatalogs[1:]:
            if numpy.any(orderedCatalogs[0].get(idKey) != catalog.get(idKey)):
                raise ValueError("Error in inputs to MergeCoaddMeasurements: source IDs do not match")

        # This first zip iterates over all the catalogs simultaneously, yielding a sequence of one
        # record for each band, in priority order.
        for orderedRecords in zip(*orderedCatalogs):

            maxSNRecord = None
            maxSNFlagKeys = None
            maxSN = 0.
            priorityRecord = None
            priorityFlagKeys = None
            prioritySN = 0.
            hasPseudoFilter = False

            # Now we iterate over those record-band pairs, keeping track of the priority and the
            # largest S/N band.
            for inputRecord, flagKeys in zip(orderedRecords, orderedKeys):
                parent = (inputRecord.getParent() == 0 and inputRecord.get(flagKeys.footprint))
                child = (inputRecord.getParent() != 0 and inputRecord.get(flagKeys.peak))

                if not (parent or child):
                    for pseudoFilterKey in self.pseudoFilterKeys:
                        if inputRecord.get(pseudoFilterKey):
                            hasPseudoFilter = True
                            priorityRecord = inputRecord
                            priorityFlagKeys = flagKeys
                            break
                    if hasPseudoFilter:
                        break

                isBad = (
                    any(inputRecord.get(flag) for flag in self.badFlags)
                    or inputRecord["deblend_dataCoverage"] == 0
                    or inputRecord.get(self.fluxFlagKey)
                    or inputRecord.get(self.instFluxErrKey) == 0
                )
                if isBad:
                    sn = 0.
                else:
                    sn = inputRecord.get(self.instFluxKey)/inputRecord.get(self.instFluxErrKey)
                if numpy.isnan(sn) or sn < 0.:
                    sn = 0.
                if (parent or child) and priorityRecord is None:
                    priorityRecord = inputRecord
                    priorityFlagKeys = flagKeys
                    prioritySN = sn
                if sn > maxSN:
                    maxSNRecord = inputRecord
                    maxSNFlagKeys = flagKeys
                    maxSN = sn

            # If the priority band has a low S/N we would like to choose the band with the highest S/N as
            # the reference band instead. However, we only want to choose the highest S/N band if it is
            # significantly better than the priority band. Therefore, to choose a band other than the
            # priority, we require that the priority S/N is below the minimum threshold and that the
            # difference between the priority and highest S/N is larger than the difference threshold.
            #
            # For pseudo-filter objects we always choose the first band in the priority list.
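            # Worked example with the default thresholds (minSN=10, minSNDiff=3;
            # the numbers are illustrative only): a priority-band S/N of 8 and a
            # best other-band S/N of 12 selects the other band (8 < 10 and
            # 12 - 8 > 3), whereas a best other-band S/N of 10 keeps the
            # priority band, because the difference (2) does not exceed minSNDiff.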

            bestRecord = None
            bestFlagKeys = None
            if hasPseudoFilter:
                bestRecord = priorityRecord
                bestFlagKeys = priorityFlagKeys
            elif (prioritySN < self.config.minSN and (maxSN - prioritySN) > self.config.minSNDiff
                  and maxSNRecord is not None):
                bestRecord = maxSNRecord
                bestFlagKeys = maxSNFlagKeys
            elif priorityRecord is not None:
                bestRecord = priorityRecord
                bestFlagKeys = priorityFlagKeys

            if bestRecord is not None and bestFlagKeys is not None:
                outputRecord = mergedCatalog.addNew()
                outputRecord.assign(bestRecord, self.schemaMapper)
                outputRecord.set(bestFlagKeys.output, True)
            else:  # if we didn't find any records
                raise ValueError("Error in inputs to MergeCoaddMeasurements: no valid reference for %s" %
                                 inputRecord.getId())

        # more checking for sane inputs, since zip silently iterates over the smallest sequence
        for inputCatalog in orderedCatalogs:
            if len(mergedCatalog) != len(inputCatalog):
                raise ValueError("Mismatch between catalog sizes: %s != %s" %
                                 (len(mergedCatalog), len(inputCatalog)))

        return pipeBase.Struct(
            mergedCatalog=mergedCatalog
        )
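
# Example of driving the task directly on in-memory catalogs (a sketch only:
# ``measSchema`` and the per-band measurement catalogs ``meas_i`` and ``meas_r``
# are hypothetical and would normally come from a prior measurement step; in
# production the pipeline framework calls runQuantum instead):
#
#     config = MergeMeasurementsConfig()
#     config.priorityList = ["i", "r"]
#     task = MergeMeasurementsTask(schema=measSchema, config=config)
#     result = task.run(catalogs={"i": meas_i, "r": meas_r})
#     refCat = result.mergedCatalog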