# This file is part of meas_extensions_multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = (
    "ConsolidateAstropyTableConfigBase",
    "ConsolidateAstropyTableConnections",
    "ConsolidateAstropyTableConfig",
    "ConsolidateAstropyTableTask",
)

from collections import defaultdict

import astropy.table as apTab
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as connectionTypes
import numpy as np

from .input_config import InputConfig


class ConsolidateAstropyTableConfigBase(pexConfig.Config):
    """Config for ConsolidateAstropyTableTask."""

    inputs = pexConfig.ConfigDictField(
        doc="Mapping of input dataset type config by name",
        keytype=str,
        itemtype=InputConfig,
        default={},
    )


class ConsolidateAstropyTableConnections(
    # Ignore the undocumented inherited config arg in __init__
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),  # numpydoc ignore=PR01
):
    """Connections for ConsolidateAstropyTableTask."""

    cat_output = connectionTypes.Output(
        doc="Per-tract horizontal concatenation of the input AstropyTables",
        name="objectAstropyTable_tract",
        storageClass="ArrowTable",
        dimensions=("tract", "skymap"),
    )

    def __init__(self, *, config: ConsolidateAstropyTableConfigBase):
        super().__init__(config=config)

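        # Add one Input connection per configured input dataset type,
        # named after its key in config.inputs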
        for name, config_input in config.inputs.items():
            if hasattr(self, name):
                raise ValueError(
                    f"{config_input=} {name=} is invalid, due to being an existing attribute" f" of {self=}"
                )
            connection = config_input.get_connection(name)
            setattr(self, name, connection)


class ConsolidateAstropyTableConfig(
    pipeBase.PipelineTaskConfig,
    ConsolidateAstropyTableConfigBase,
    pipelineConnections=ConsolidateAstropyTableConnections,
):
    """PipelineTaskConfig for ConsolidateAstropyTableTask."""

    drop_duplicate_columns = pexConfig.Field[bool](
        doc="Whether to drop columns from a table if they occur in a previous table."
        " If False, astropy will rename them with its default scheme.",
        default=True,
    )
    join_type = pexConfig.ChoiceField[str](
        doc="Type of join to perform in the final hstack",
        allowed={
            "inner": "Inner join",
            "outer": "Outer join",
            "exact": "Exact join",
        },
        default="exact",
        optional=False,
    )
    validate_duplicate_columns = pexConfig.Field[bool](
        doc="Whether to check that duplicate columns are identical in any table they occur in.",
        default=True,
    )


class ConsolidateAstropyTableTask(pipeBase.PipelineTask):
    """Write patch-merged astropy tables to a tract-level astropy table."""

    _DefaultName = "consolidateAstropyTable"
    ConfigClass = ConsolidateAstropyTableConfig

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

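        # band_null/patch_null are sentinel keys for multiband/multipatch
        # datasets, which are not associated with a single band or patch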
        bands_ref, patches_ref = None, None
        band_null, patch_null = "", -1
        bands_null, patches_null = {band_null}, {patch_null: None}
        data = dict()
        bands_sorted = None

        # inputRefs are usually unsorted lists so they need to be sorted first
        for name, inputRef_list in inputRefs:
            inputConfig = self.config.inputs[name]
            bands, patches = set(), dict()
            data_name = defaultdict(dict)
            inputs_name = inputs[name]

            # if it's not a list, then it's a single object
            if not hasattr(inputRef_list, "__len__"):
                inputRef_list = tuple((inputRef_list,))
                inputs_name = tuple((inputs_name,))

            # Add every ref by band (if not multiband)
            for dataRef, data_in in zip(inputRef_list, inputs_name):
                dataId = dataRef.dataId
                band = dataId.band.name if not inputConfig.is_multiband else band_null

                if inputConfig.columns is not None:
                    columns = inputConfig.columns
                    data_in = data_in.get(parameters={"columns": columns})
                else:
                    columns = tuple(data_in.columns)

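                # Convert DataFrame inputs to astropy Tables; for ArrowAstropy
                # inputs, the metadata is nested under the dataset name so that
                # each input's metadata remains distinguishable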
                if inputConfig.storageClass == "DataFrame":
                    data_in = apTab.Table.from_pandas(data_in.reset_index(drop=False))
                elif inputConfig.storageClass == "ArrowAstropy":
                    data_in.meta = {name: data_in.meta}

                if not inputConfig.is_multiband:
                    columns_new = [
                        column if column == inputConfig.column_id else f"{band}_{column}"
                        for column in columns
                    ]
                    data_in.rename_columns(columns, columns_new)
                if inputConfig.action is not None:
                    data_in = inputConfig.action(data_in, datasetType=name)

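                # Record the patch this table belongs to, along with its
                # minimum id value (used later to set the patch order)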
                if inputConfig.is_multipatch:
                    patch = patch_null
                    patches[patch] = None
                else:
                    patch = dataId.patch.id
                    patches[patch] = min(data_in[inputConfig.column_id])
                data_name[patch][band] = data_in
                bands.add(band)

            # Validate the bands
            if inputConfig.is_multiband:
                if bands != bands_null:
                    raise RuntimeError(f"multiband {inputConfig=} has non-trivial {bands=}")
            else:
                if bands_ref is None:
                    bands_ref = bands
                    bands_sorted = tuple(band for band in sorted(bands_ref))
                else:
                    if bands != bands_ref:
                        raise RuntimeError(f"{inputConfig=} {bands=} != {bands_ref=}")

            # Check that every dataset has the same set of patches
            if inputConfig.is_multipatch:
                if patches != patches_null:
                    raise RuntimeError(f"{inputConfig=} {patches=} != {patches_null=}")
            else:
                column_id = inputConfig.column_id
                if patches_ref is None:
                    bands = tuple(bands) if inputConfig.is_multiband else bands_sorted
                    for patch in patches:
                        data_patch = data_name[patch]
                        # Make sure any one-time operations are done once
                        # rather than for every band
                        added = False
                        for band in bands:
                            if tab := data_patch.get(band):
                                if not added:
                                    # add a patch column to fill in later
                                    tab.add_column(np.full(len(tab), patch), name="patch", index=1)
                                    # The id column should be objectId
                                    tab.rename_column(column_id, "objectId")
                                    added = True
                                else:
                                    del tab[column_id]

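                    # Sort patches by their minimum id; patches_ref fixes the
                    # patch order used when stacking the tables below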
                    patches_objid = {objid: patch for patch, objid in patches.items()}
                    patches_ref = {patch: objid for objid, patch in sorted(patches_objid.items())}
                elif {patch: patches[patch] for patch in patches_ref.keys()} != patches_ref:
                    raise RuntimeError(f"{inputConfig=} {patches=} != {patches_ref=}")
                else:
                    for data_patch in data_name.values():
                        for tab in data_patch.values():
                            del tab[column_id]

            data[name] = data_name

        self.log.info("Concatenating %s per-patch astropy Tables", len(patches))

        tables_read = []
        check_columns = self.config.drop_duplicate_columns or self.config.validate_duplicate_columns
        n_bands = len(bands_sorted)

        for name, data_name in data.items():
            config_input = self.config.inputs[name]
            tables = []
            bands_missing = False

            # If this is a multipatch dataset, loop over patches
            # Otherwise, loop over the single "null" patch
            for patch in patches_ref if not config_input.is_multipatch else patches_null:
                data_name_patch = data_name[patch]
                # If this is multiband, use the null band, and return an empty
                # list if there's no corresponding dataset
                if config_input.is_multiband:
                    tables_patch = data_name_patch.get(band_null, [])
                else:
                    # Get the tables (or None if it's missing) in sorted order
                    tables_patch = [
                        _tab for band in bands_sorted if (_tab := data_name_patch.get(band)) is not None
                    ]
                    # Check if any bands are missing
                    if not bands_missing and (len(tables_patch) != n_bands):
                        bands_missing = True
                # Join only if there's something to join
                if tables_patch:
                    table_patch = apTab.hstack(tables_patch, join_type="exact")
                    tables.append(table_patch)
                # If there's nothing to join, presumably the task failed
                # stacking should handle some tasks failing but not others, but
                # this shouldn't be relied upon

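            # Stack the per-patch tables vertically; if any band was missing,
            # an outer join is used so that the absent columns are masked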
            table_new = (
                tables[0]
                if (len(tables) == 1)
                else apTab.vstack(tables, join_type="outer" if bands_missing else "exact")
            )

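            # Check for columns duplicated across datasets; the join column is
            # excluded because it is still needed as a key for the final join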
            if check_columns:
                columns_new = set(x for x in table_new.colnames if x != config_input.join_column)
                for name_previous in tables_read:
                    table_old = data[name_previous]
                    columns_common = columns_new.intersection(
                        x for x in table_old.colnames if x != self.config.inputs[name_previous].join_column
                    )
                    for column_common in columns_common:
                        if self.config.validate_duplicate_columns:
                            if not np.array_equal(
                                table_new[column_common],
                                table_old[column_common],
                                equal_nan=True,
                            ):
                                raise RuntimeError(
                                    f"Joined table column={column_common} differs between {name} and"
                                    f" {name_previous} tables"
                                )
                        if self.config.drop_duplicate_columns:
                            del table_new[column_common]

            data[name] = table_new
            tables_read.append(name)

        # This will break if all tables have config.join_column
        # ... but that seems unlikely.
        table = apTab.hstack(
            [data[name] for name, config in self.config.inputs.items() if config.join_column is None],
            join_type=self.config.join_type,
        )

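        # Tables with a join_column are joined on that key instead of being
        # horizontally stacked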
        for name, config in self.config.inputs.items():
            if config.join_column:
                table = apTab.join(
                    table,
                    data[name],
                    join_type=self.config.join_type,
                    keys=config.join_column,
                )

        butlerQC.put(pipeBase.Struct(cat_output=table), outputRefs)