Coverage for python/lsst/faro/measurement/TractTableValueMeasurement.py: 36%

59 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-12 02:56 -0700

1# This file is part of faro. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import astropy.units as u 

23 

24import lsst.pex.config as pexConfig 

25import lsst.pipe.base.connectionTypes as cT 

26from lsst.pipe.base import Struct 

27from lsst.pipe.tasks.diff_matched_tract_catalog import SourceType 

28from lsst.pipe.tasks.configurableActions import ConfigurableActionField 

29from lsst.pipe.tasks.dataFrameActions import SingleColumnAction 

30from lsst.verify import Measurement 

31from lsst.verify.tasks import MetricTask, MetricConfig, MetricConnections, MetricComputationError 

32 

# Names exported by ``from ... import *``: the connections, config and task
# classes defined in this module.
__all__ = (
    "TractTableValueMeasurementConnections",
    "TractTableValueMeasurementConfig",
    "TractTableValueMeasurementTask",
)

38 

39 

class TractTableValueMeasurementConnections(
    MetricConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates={"package": None, "metric": None, "name_table": None},
):
    """Butler connections for `TractTableValueMeasurementTask`.

    The ``name_table`` template names the per-tract input table, while the
    ``package`` and ``metric`` templates name the output metric values.
    """

    # Column index of the input table. The visible task code reads it through
    # ``butlerQC.get`` but does not use it; NOTE(review): presumably declared
    # so the quantum depends on it — confirm.
    columns = cT.Input(
        name="{name_table}.columns",
        doc="Table columns to read",
        dimensions=("tract", "skymap"),
        storageClass="DataFrameIndex",
    )
    # ``multiple=True`` plus the extra "band" dimension: one metric value per
    # band, the band being taken from each output ref's data ID.
    measurement = cT.Output(
        name="metricvalue_{package}_{metric}",
        doc="The metric value computed by this task.",
        dimensions=("tract", "skymap", "band"),
        storageClass="MetricValue",
        multiple=True,
    )
    # Deferred so that only the required columns are fetched later on.
    table = cT.Input(
        name="{name_table}",
        doc="Table to read value from",
        dimensions=("tract", "skymap"),
        storageClass="DataFrame",
        deferLoad=True,
    )

65 

66 

class TractTableValueMeasurementConfig(
    MetricConfig, pipelineConnections=TractTableValueMeasurementConnections
):
    """Configuration for TractTableValueMeasurementTask."""

    action = ConfigurableActionField(
        doc="Action to compute the value with",
        default=SingleColumnAction,
    )
    band_order = pexConfig.ListField(
        dtype=str,
        doc="Standard (usually wavelength-based) ordering for possible bands"
            " to determine standard colors",
        default=('u', 'g', 'r', 'i', 'z', 'y'),
    )
    format_column = pexConfig.Field(
        dtype=str,
        doc="Format of the full column names including the band",
        default="{band}_{column}",
    )
    prefixes_column = pexConfig.ListField(
        dtype=str,
        doc="Column name prefixes to ignore when applying special formatting rules",
        default=[f'{x.value.label}_' for x in SourceType],
    )
    row = pexConfig.Field(
        dtype=int,
        doc="Index of the row to retrieve the value from",
        optional=False,
    )
    unit = pexConfig.Field(
        dtype=str,
        doc="The astropy unit of the metric value",
        default='',
    )

    def _format_column(self, band: str, column: str) -> str:
        """Return the full table column name for a base column and a band.

        Parameters
        ----------
        band : `str`
            The band to build the column name for.
        column : `str`
            The base column name, as referenced by ``action``.

        Returns
        -------
        name : `str`
            The full column name, built with ``format_column``.

        Raises
        ------
        ValueError
            Raised if ``column`` is a color column and ``band`` is either not
            in ``band_order`` or is its first entry, which has no preceding
            band to form a color with.
        """
        # Strip the first matching prefix; it is re-attached unchanged below.
        prefix = ''
        for prefix_column in self.prefixes_column:
            if column.startswith(prefix_column):
                prefix = prefix_column
                column = column[len(prefix):]
                break

        if column.startswith('color_'):
            if band not in self.band_order:
                raise ValueError(
                    f"Cannot format color column {column!r}: band={band!r} is not in"
                    f" band_order={list(self.band_order)}"
                )
            index_band = self.band_order.index(band)
            # Without this check, index 0 would silently wrap around to the
            # last band via band_order[-1] and produce a meaningless color.
            if index_band == 0:
                raise ValueError(
                    f"Cannot format color column {column!r}: band={band!r} is the first"
                    f" entry of band_order={list(self.band_order)} and has no preceding band"
                )
            band_prev = self.band_order[index_band - 1]
            suffix = column[len('color_'):]
            # NOTE(review): ``band`` appears twice in this name; kept as-is,
            # presumably to match the upstream table's column naming — confirm.
            column = f'color_{band_prev}_m_{band}_{band}_{suffix}'

        if column.startswith('flux_'):
            column = f'flux_{band}_{column[len("flux_"):]}'
        elif column.startswith('mag_'):
            column = f'mag_{band}_{column[len("mag_"):]}'
        return self.format_column.format(band=band, column=f'{prefix}{column}')

116 

117 

class TractTableValueMeasurementTask(MetricTask):
    """Measure a metric from a single row and combination of columns in a table."""

    ConfigClass = TractTableValueMeasurementConfig
    _DefaultName = "TractTableValueMeasurementTask"

    def run(self, table, bands, name_metric):
        """Compute one measurement per band from a single table row.

        Parameters
        ----------
        table : `pandas.DataFrame`
            Table containing the band-specific columns; presumably a pandas
            DataFrame (storage class "DataFrame", indexed with ``iloc``).
        bands : `list` [`str`]
            Bands to compute a measurement for.
        name_metric : `str`
            Name of the metric for each `~lsst.verify.Measurement`.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with ``measurement``, a list of measurements in the same
            order as ``bands``.
        """
        unit = u.Unit(self.config.unit)
        columns = list(self.config.action.columns)
        # Select the configured row once; a list indexer keeps a one-row
        # DataFrame rather than a Series, as the action expects a table.
        row = table.iloc[[self.config.row]]
        measurements = []
        for band in bands:
            # Map this band's full column names back to the base names the
            # action operates on.
            row_band = row.rename(
                columns={self.config._format_column(band, column): column for column in columns}
            )
            value = self.config.action(row_band).iloc[0]
            measurements.append(Measurement(name_metric, value*unit))
        return Struct(measurement=measurements)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read the required table columns, compute and write per-band values.

        A `MetricComputationError` raised while reading or computing is
        logged and no outputs are written; other exceptions propagate.
        """
        try:
            inputs = butlerQC.get(inputRefs)
            # One output ref per band; the band comes from each ref's data ID.
            bands = [ref.dataId['band'] for ref in outputRefs.measurement]
            columns_base = list(self.config.action.columns)
            # De-duplicate (preserving order) in case format_column yields the
            # same name for several bands, e.g. when it omits "{band}".
            columns_in = list(dict.fromkeys(
                self.config._format_column(band, column)
                for band in bands
                for column in columns_base
            ))

            # If columns_in contains non-existent columns, the get call will fail
            outputs = self.run(
                table=inputs['table'].get(parameters={'columns': columns_in}),
                bands=bands,
                name_metric=self.config.connections.metric,
            )
        except MetricComputationError:
            self.log.error(
                "Measurement of %r failed on %s->%s",
                self, inputRefs, outputRefs, exc_info=True)
        else:
            butlerQC.put(outputs, outputRefs)