Coverage for python / lsst / analysis / tools / tasks / makeMetricTable.py: 17%

94 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 18:53 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ( 

24 "MakeMetricTableConfig", 

25 "MakeMetricTableTask", 

26) 

27 

28import lsst.pipe.base as pipeBase 

29import numpy as np 

30from astropy import units as u 

31from astropy.table import Table 

32from lsst.pex.config import ListField 

33from lsst.pipe.base import connectionTypes as ct 

34from lsst.skymap import BaseSkyMap 

35 

36from ..utils import getTractCorners 

37 

38 

class MakeMetricTableConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=(),
    defaultTemplates={"metricBundleName": "", "outputTableName": ""},
):
    """Connections for the metric-table task.

    The class-level connections are declared with empty dimensions. The
    real dimensions of the ``data`` input, of the ``metricTable`` output and
    of the task itself are taken from the config when an instance is
    constructed.
    """

    data = ct.Input(
        doc="Metric bundle to read from the butler",
        name="{metricBundleName}",
        storageClass="MetricMeasurementBundle",
        deferLoad=True,
        dimensions=(),
        multiple=True,
    )

    skymap = ct.Input(
        doc="The skymap that covers the tract that the data is from.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    metricTable = ct.Output(
        doc="Table containing metrics, one row per input metric bundle",
        name="{outputTableName}",
        storageClass="ArrowAstropy",
        dimensions=(),
    )

    def __init__(self, *, config=None):
        """Rebuild the connections with the dimensions named in ``config``.

        Parameters
        ----------
        config : `MakeMetricTableConfig`
            Task configuration. ``inputDataDimensions`` and
            ``outputTableDimensions`` are read from it. Must not be `None`.
        """
        super().__init__(config=config)

        # Check before ``config`` is first dereferenced below; placed after
        # the base-class initialiser so that whatever it raises for an
        # unusable config is left unchanged.
        assert config is not None, "Missing required config object."

        inputDimensions = frozenset(sorted(config.inputDataDimensions))
        outputDimensions = frozenset(sorted(config.outputTableDimensions))

        # Re-create the input/output connections, copying every attribute
        # from the class-level declaration except the dimensions.
        self.data = ct.Input(
            doc=self.data.doc,
            name=self.data.name,
            storageClass=self.data.storageClass,
            deferLoad=self.data.deferLoad,
            dimensions=inputDimensions,
            multiple=self.data.multiple,
        )
        self.metricTable = ct.Output(
            doc=self.metricTable.doc,
            name=self.metricTable.name,
            storageClass=self.metricTable.storageClass,
            dimensions=outputDimensions,
        )

        # The skymap input is only kept when the inputs are tract-level.
        if "tract" not in config.inputDataDimensions:
            del self.skymap

        self.dimensions.update(outputDimensions)

90 

91 

class MakeMetricTableConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=MakeMetricTableConnections,
):
    """Configuration for the metric-table task and its connections."""

    # Dimensions applied to the ``data`` input connection.
    inputDataDimensions = ListField[str](
        default=("skymap", "tract"),
        optional=False,
        doc="Dimensions of the input data.",
    )

    # Dimensions applied to the ``metricTable`` output connection and to
    # the task itself.
    outputTableDimensions = ListField[str](
        default=("skymap",),
        optional=False,
        doc="Dimensions of the output data.",
    )

    # DataId keys that are copied into the table as extra columns.
    dataIdFieldsToIncludeAsColumns = ListField[str](
        default=("tract",),
        optional=False,
        doc=(
            "DataId fields to include as columns in the table. These are added "
            "in addition to the Metric names. At least one field must be specified."
        ),
    )

113 

114 

class MakeMetricTableTask(pipeBase.PipelineTask):
    """Turn metric bundles and combine them into a metric table."""

    ConfigClass = MakeMetricTableConfig
    _DefaultName = "makeMetricTable"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Take a set of metric bundles, separate each into its different
        metrics, then put the values into a table with the metric names as
        column headers.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.QuantumContext`
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
        """
        inputs = butlerQC.get(inputRefs)
        # The skymap connection is deleted for non tract-level inputs, so it
        # may legitimately be missing here.
        skymap = inputs["skymap"] if "skymap" in inputs else None

        # Extract the info from the dataIds that is needed
        # to populate the requested columns.
        fields = self.config.dataIdFieldsToIncludeAsColumns
        dataIdInfo = [{field: ref.dataId[field] for field in fields} for ref in inputRefs.data]

        # The bundles are deferred-load handles; read them all now.
        metricBundles = [inputHandle.get() for inputHandle in inputs["data"]]

        outputs = self.run(dataIdInfo, metricBundles, skymap)
        butlerQC.put(outputs, outputRefs)

    def run(self, dataIdInfo, metricBundles, skymap):
        """Take the metric bundles and expand them out, then make a table of
        the information. Add tract corner information if the bundles are
        tract-level.

        Each bundle becomes one row. Columns are the requested dataId
        fields, ``corners`` (tract-level inputs only) and one column per
        ``<bundle key>_<metric name>``. A metric that is missing from a
        given bundle is filled with NaN in that row.

        Parameters
        ----------
        dataIdInfo : `list`
            A list of dicts that hold information extracted from the metric
            bundle dataIds. Entry ``i`` describes ``metricBundles[i]``. When
            ``"tract"`` is in ``config.inputDataDimensions`` every dict must
            contain a ``"tract"`` key.
        metricBundles : `list` of
            `lsst.analysis.tools.interfaces._metricMeasurementBundle.MetricMeasurementBundle`
        skymap : `lsst.skymap`
            Only used for tract-level inputs; may be `None` otherwise.

        Returns
        -------
        metricTableStruct : `pipe.base.Struct` containing `astropy.table.Table`
            The table is stored under the ``metricTable`` attribute.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if ``dataIdInfo`` or ``metricBundles`` is empty.
        """
        if len(dataIdInfo) == 0:
            raise pipeBase.NoWorkFound("dataIdInfo list is empty")
        if len(metricBundles) == 0:
            raise pipeBase.NoWorkFound("metricBundles list is empty")

        tractLevel = "tract" in self.config.inputDataDimensions

        metricsDict = {}
        metricUnits = {}

        # Add requested info from the first dataId to the metrics dict.
        for key, value in dataIdInfo[0].items():
            metricsDict[key] = [value]

        # Add tract corners if inputs are at the tract-level
        if tractLevel:
            metricsDict["corners"] = [getTractCorners(skymap, dataIdInfo[0]["tract"])]

        # Add the metrics and units from the first bundle to the dicts
        for name, metrics in metricBundles[0].items():
            percentUnitUsed = False
            for metric in metrics:
                fullName = f"{name}_{metric.metric_name}"
                metricsDict[fullName] = [metric.quantity.value]
                unit = metric.quantity.unit
                # "Dimensionless" and percent units not allowed in Tables:
                if unit is u.dimensionless_unscaled:
                    continue
                if unit is u.pct:
                    percentUnitUsed = True
                    self.log.debug(
                        "Unable to propagate astropy percent unit for metric %s "
                        "in bundle %s to astropy table.",
                        metric.metric_name,
                        name,
                    )
                    continue
                metricUnits[fullName] = unit
            if percentUnitUsed:
                # One summary warning per bundle key; the per-metric detail
                # is only emitted at debug level above.
                self.log.warning(
                    "One or more metrics in the %s metric bundle uses the percent unit, "
                    "which is not supported in astropy Tables. "
                    "The value(s) have been propagated without the percent unit. "
                    "Use --log-level debug to list all affected metrics.",
                    name,
                )

        # Remaining bundles: append to existing columns and create any
        # additional columns that are needed.
        for rowIndex, metricBundle in enumerate(metricBundles[1:], start=1):
            # Names of the columns that received a value for this row.
            filledColumns = set()

            for key, value in dataIdInfo[rowIndex].items():
                metricsDict[key].append(value)
                filledColumns.add(key)

            if tractLevel:
                metricsDict["corners"].append(getTractCorners(skymap, dataIdInfo[rowIndex]["tract"]))
                filledColumns.add("corners")

            for name, metrics in metricBundle.items():
                for metric in metrics:
                    fullName = f"{name}_{metric.metric_name}"
                    # Check if the metric already exists in the output
                    if fullName in metricsDict:
                        metricsDict[fullName].append(metric.quantity.value)
                    else:
                        # New column: back-fill one NaN per earlier row.
                        # ``rowIndex`` is exactly the number of rows already
                        # written, so no other column needs to be consulted.
                        values = [np.nan] * rowIndex
                        values.append(metric.quantity.value)
                        unit = metric.quantity.unit
                        if unit is not u.dimensionless_unscaled and unit is not u.pct:
                            metricUnits[fullName] = unit
                        metricsDict[fullName] = values
                    filledColumns.add(fullName)

            # If a metric that existed in a previous bundle does
            # not exist for this one then add a nan
            for columnName, columnValues in metricsDict.items():
                if columnName not in filledColumns:
                    columnValues.append(np.nan)

        metricTableStruct = pipeBase.Struct(metricTable=Table(metricsDict, units=metricUnits))
        return metricTableStruct