Coverage for python / lsst / analysis / tools / tasks / makeMetricTable.py: 17%

94 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 09:07 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Names exported by ``from <module> import *``: the task and its config class.
__all__ = (
    "MakeMetricTableConfig",
    "MakeMetricTableTask",
)

27 

28import numpy as np 

29from astropy import units as u 

30from astropy.table import Table 

31 

32import lsst.pipe.base as pipeBase 

33from lsst.pex.config import ListField 

34from lsst.pipe.base import connectionTypes as ct 

35from lsst.skymap import BaseSkyMap 

36 

37from ..utils import getTractCorners 

38 

39 

class MakeMetricTableConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=(),
    defaultTemplates={"metricBundleName": "", "outputTableName": ""},
):
    """Butler connections for `MakeMetricTableTask`.

    The dimensions given to ``data`` and ``metricTable`` at class level are
    placeholders. ``__init__`` rebuilds both connections with the dimensions
    taken from ``config.inputDataDimensions`` and
    ``config.outputTableDimensions``, and removes the ``skymap`` input when
    the input data are not tract-level.
    """

    data = ct.Input(
        doc="Metric bundle to read from the butler",
        name="{metricBundleName}",
        storageClass="MetricMeasurementBundle",
        deferLoad=True,
        dimensions=(),
        multiple=True,
    )

    skymap = ct.Input(
        doc="The skymap that covers the tract that the data is from.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    metricTable = ct.Output(
        doc="Table containing metrics, one row per input metric bundle",
        name="{outputTableName}",
        storageClass="ArrowAstropy",
        dimensions=(),
    )

    def __init__(self, *, config=None):
        """Rebuild the connections using the dimensions set in ``config``.

        Parameters
        ----------
        config : `MakeMetricTableConfig`
            Task configuration. Required, despite the `None` default.

        Raises
        ------
        ValueError
            Raised if ``config`` is `None`.
        """
        super().__init__(config=config)

        # Fail early with a clear message, before any attribute of ``config``
        # is read below. An explicit raise is not stripped by ``python -O``.
        if config is None:
            raise ValueError("Missing required config object.")

        inputDimensions = frozenset(sorted(config.inputDataDimensions))
        outputDimensions = frozenset(sorted(config.outputTableDimensions))

        # Re-create the connections: only the dimensions differ from the
        # class-level definitions, every other property is carried over.
        self.data = ct.Input(
            doc=self.data.doc,
            name=self.data.name,
            storageClass=self.data.storageClass,
            deferLoad=self.data.deferLoad,
            dimensions=inputDimensions,
            multiple=self.data.multiple,
        )
        self.metricTable = ct.Output(
            doc=self.metricTable.doc,
            name=self.metricTable.name,
            storageClass=self.metricTable.storageClass,
            dimensions=outputDimensions,
        )

        # The skymap is only used to look up tract corners, so it is dropped
        # as an input when the input data are not tract-level.
        if "tract" not in config.inputDataDimensions:
            del self.skymap

        # The quantum dimensions follow the output table dimensions.
        self.dimensions.update(outputDimensions)

91 

92 

class MakeMetricTableConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=MakeMetricTableConnections,
):
    """Configuration for `MakeMetricTableTask`."""

    # Dimensions given to the ``data`` input connection. When "tract" is
    # listed, the skymap input is kept and a "corners" column is added to the
    # output table.
    inputDataDimensions = ListField[str](
        doc="Dimensions of the input data.",
        default=("skymap", "tract"),
        optional=False,
    )
    # Dimensions given to the ``metricTable`` output connection and added to
    # the dimensions of the connections class itself.
    outputTableDimensions = ListField[str](
        doc="Dimensions of the output data.",
        default=("skymap",),
        optional=False,
    )
    # DataId keys that are copied from each input dataset reference into the
    # output table, one column per key.
    dataIdFieldsToIncludeAsColumns = ListField[str](
        doc="DataId fields to include as columns in the table. "
        "These are added in addition to the Metric names. "
        "At least one field must be specified.",
        default=("tract",),
        optional=False,
    )

114 

115 

class MakeMetricTableTask(pipeBase.PipelineTask):
    """Combine a set of metric bundles into a single metric table.

    Each input metric bundle becomes one row of the output table and each
    metric becomes one column, named ``<bundle key>_<metric name>``.
    """

    ConfigClass = MakeMetricTableConfig
    _DefaultName = "makeMetricTable"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Take a set of metric bundles, separate each into its different
        metrics, then put the values into a table with the metric names as
        column headers.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.QuantumContext`
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
        """
        inputs = butlerQC.get(inputRefs)

        # The skymap connection is removed for non-tract-level inputs, in
        # which case there is nothing to read.
        skymap = inputs["skymap"] if "skymap" in inputs else None

        # Extract the info from the dataIds that is needed
        # to populate the requested columns.
        fields = self.config.dataIdFieldsToIncludeAsColumns
        dataIdInfo = [{field: ref.dataId[field] for field in fields} for ref in inputRefs.data]

        # The bundles are deferred-load handles; read them all now.
        metricBundles = [handle.get() for handle in inputs["data"]]

        outputs = self.run(dataIdInfo, metricBundles, skymap)
        butlerQC.put(outputs, outputRefs)

    def run(self, dataIdInfo, metricBundles, skymap):
        """Take the metric bundles and expand them out, then make a table of
        the information. Add tract corner information if the bundles are
        tract-level.

        Parameters
        ----------
        dataIdInfo : `list` of `dict`
            One dict per metric bundle, holding the information extracted
            from that bundle's dataId. Must be in the same order as
            ``metricBundles``.
        metricBundles : `list` of
            `lsst.analysis.tools.interfaces._metricMeasurementBundle.MetricMeasurementBundle`
            The metric bundles to tabulate, one table row per bundle.
        skymap : `lsst.skymap.BaseSkyMap` or `None`
            Skymap used to look up the tract corners. Only used when
            "tract" is in ``config.inputDataDimensions``.

        Returns
        -------
        metricTableStruct : `lsst.pipe.base.Struct`
            Struct with a single attribute, ``metricTable``, an
            `astropy.table.Table` with one row per metric bundle.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if ``dataIdInfo`` or ``metricBundles`` is empty.
        """
        if len(dataIdInfo) == 0:
            raise pipeBase.NoWorkFound("dataIdInfo list is empty")
        if len(metricBundles) == 0:
            raise pipeBase.NoWorkFound("metricBundles list is empty")

        addCorners = "tract" in self.config.inputDataDimensions

        # Column name -> list of values (one per row), and column name -> unit.
        metricsDict = {}
        metricUnits = {}

        # Add requested info from the first dataId to the metrics dict.
        for key, value in dataIdInfo[0].items():
            metricsDict[key] = [value]

        # Add tract corners if inputs are at the tract-level
        if addCorners:
            metricsDict["corners"] = [getTractCorners(skymap, dataIdInfo[0]["tract"])]

        # Add the metrics and units from the first bundle to the dicts
        for name, metrics in metricBundles[0].items():
            percentUnitUsed = False
            for metric in metrics:
                fullName = f"{name}_{metric.metric_name}"
                unit = metric.quantity.unit
                metricsDict[fullName] = [metric.quantity.value]
                # "Dimensionless" and percent units not allowed in Tables.
                # Compare by equality, not identity, so that an equal unit
                # that is a different object is also recognised.
                if unit == u.dimensionless_unscaled:
                    continue
                if unit == u.pct:
                    percentUnitUsed = True
                    self.log.debug(
                        "Unable to propagate astropy percent unit for metric %s "
                        "in bundle %s to astropy table.",
                        metric.metric_name,
                        name,
                    )
                    continue
                metricUnits[fullName] = unit
            if percentUnitUsed:
                # One summary warning per bundle key; details are at debug.
                self.log.warning(
                    "One or more metrics in the %s metric bundle uses the percent unit, "
                    "which is not supported in astropy Tables. "
                    "The value(s) have been propagated without the percent unit. "
                    "Use --log-level debug to list all affected metrics.",
                    name,
                )

        # Add one row per remaining bundle, creating new columns as needed.
        # ``row`` is the row index, which equals the number of rows that are
        # already in the table.
        for row, metricBundle in enumerate(metricBundles[1:], start=1):
            # Columns that received a value for this row.
            filledColumns = set()

            for key, value in dataIdInfo[row].items():
                metricsDict[key].append(value)
                filledColumns.add(key)

            if addCorners:
                metricsDict["corners"].append(getTractCorners(skymap, dataIdInfo[row]["tract"]))
                filledColumns.add("corners")

            for name, metrics in metricBundle.items():
                for metric in metrics:
                    fullName = f"{name}_{metric.metric_name}"
                    unit = metric.quantity.unit
                    # Check if the metric already exists in the output
                    if fullName in metricsDict:
                        metricsDict[fullName].append(metric.quantity.value)
                    else:
                        # New column: back-fill every earlier row with nan.
                        values = [np.nan] * row
                        values.append(metric.quantity.value)
                        if unit != u.dimensionless_unscaled and unit != u.pct:
                            metricUnits[fullName] = unit
                        metricsDict[fullName] = values
                    filledColumns.add(fullName)

            # If a metric that existed in a previous bundle does
            # not exist for this one then add a nan
            for columnName, columnValues in metricsDict.items():
                if columnName not in filledColumns:
                    columnValues.append(np.nan)

        metricTableStruct = pipeBase.Struct(metricTable=Table(metricsDict, units=metricUnits))
        return metricTableStruct