Coverage for python/lsst/meas/extensions/multiprofit/consolidate_astropy_table.py: 0%
133 statements
coverage.py v7.13.5, created at 2026-04-18 09:21 +0000
# This file is part of meas_extensions_multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = (
    "ConsolidateAstropyTableConfigBase",
    "ConsolidateAstropyTableConnections",
    "ConsolidateAstropyTableConfig",
    "ConsolidateAstropyTableTask",
)

from collections import defaultdict

import astropy.table as apTab
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as connectionTypes
import numpy as np

from .input_config import InputConfig


class ConsolidateAstropyTableConfigBase(pexConfig.Config):
    """Config for ConsolidateAstropyTableTask."""

    inputs = pexConfig.ConfigDictField(
        doc="Mapping of input dataset type config by name",
        keytype=str,
        itemtype=InputConfig,
        default={},
    )


class ConsolidateAstropyTableConnections(
    # Ignore the undocumented inherited config arg in __init__
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),  # numpydoc ignore=PR01
):
    """Connections for ConsolidateAstropyTableTask."""

    cat_output = connectionTypes.Output(
        doc="Per-tract horizontal concatenation of the input AstropyTables",
        name="objectAstropyTable_tract",
        storageClass="ArrowTable",
        dimensions=("tract", "skymap"),
    )

    def __init__(self, *, config: ConsolidateAstropyTableConfigBase):
        super().__init__(config=config)
        for name, config_input in config.inputs.items():
            if hasattr(self, name):
                raise ValueError(
                    f"{config_input=} {name=} is invalid due to being an existing attribute"
                    f" of {self=}"
                )
            connection = config_input.get_connection(name)
            setattr(self, name, connection)


class ConsolidateAstropyTableConfig(
    pipeBase.PipelineTaskConfig,
    ConsolidateAstropyTableConfigBase,
    pipelineConnections=ConsolidateAstropyTableConnections,
):
    """PipelineTaskConfig for ConsolidateAstropyTableTask."""

    drop_duplicate_columns = pexConfig.Field[bool](
        doc="Whether to drop columns from a table if they occur in a previous table."
        " If False, astropy will rename them with its default scheme.",
        default=True,
    )
    join_type = pexConfig.ChoiceField[str](
        doc="Type of join to perform in the final hstack",
        allowed={
            "inner": "Inner join",
            "outer": "Outer join",
            "exact": "Exact join",
        },
        default="exact",
        optional=False,
    )
    validate_duplicate_columns = pexConfig.Field[bool](
        doc="Whether to check that duplicate columns are identical in any table they occur in.",
        default=True,
    )


class ConsolidateAstropyTableTask(pipeBase.PipelineTask):
    """Write patch-merged astropy tables to a tract-level astropy table."""

    _DefaultName = "consolidateAstropyTable"
    ConfigClass = ConsolidateAstropyTableConfig

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
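        # Null sentinels stand in for the band/patch of multiband and/or
        # multipatch datasets (i.e. inputs not split by band/patch).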
        bands_ref, patches_ref = None, None
        band_null, patch_null = "", -1
        bands_null, patches_null = {band_null}, {patch_null: None}
        data = dict()
        bands_sorted = None

        # inputRefs are usually unsorted lists so they need to be sorted first
        for name, inputRef_list in inputRefs:
            inputConfig = self.config.inputs[name]
            bands, patches = set(), dict()
            data_name = defaultdict(dict)
            inputs_name = inputs[name]

            # if it's not a list, then it's a single object
            if not hasattr(inputRef_list, "__len__"):
                inputRef_list = tuple((inputRef_list,))
                inputs_name = tuple((inputs_name,))

            # Add every ref by band (if not multiband)
            for dataRef, data_in in zip(inputRef_list, inputs_name):
                dataId = dataRef.dataId
                band = dataId.band.name if not inputConfig.is_multiband else band_null
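
                # Read only the configured column subset, if one is given;
                # otherwise keep every column in the loaded table.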
                if inputConfig.columns is not None:
                    columns = inputConfig.columns
                    data_in = data_in.get(parameters={"columns": columns})
                else:
                    columns = tuple(data_in.columns)
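
                # Normalize by storage class: convert DataFrame inputs to an
                # astropy Table (keeping the index as a column) and nest
                # ArrowAstropy metadata under the dataset name.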
                if inputConfig.storageClass == "DataFrame":
                    data_in = apTab.Table.from_pandas(data_in.reset_index(drop=False))
                elif inputConfig.storageClass == "ArrowAstropy":
                    data_in.meta = {name: data_in.meta}
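
                # For per-band inputs, prefix column names with the band so
                # they stay unique after the horizontal stack; the ID column
                # keeps its original name.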
                if not inputConfig.is_multiband:
                    columns_new = [
                        column if column == inputConfig.column_id else f"{band}_{column}"
                        for column in columns
                    ]
                    data_in.rename_columns(columns, columns_new)
                if inputConfig.action is not None:
                    data_in = inputConfig.action(data_in, datasetType=name)

                if inputConfig.is_multipatch:
                    patch = patch_null
                    patches[patch] = None
                else:
                    patch = dataId.patch.id
                    patches[patch] = min(data_in[inputConfig.column_id])
                data_name[patch][band] = data_in
                bands.add(band)

            # Validate the bands
            if inputConfig.is_multiband:
                if bands != bands_null:
                    raise RuntimeError(f"multiband {inputConfig=} has non-trivial {bands=}")
            else:
                if bands_ref is None:
                    bands_ref = bands
                    bands_sorted = tuple(band for band in sorted(bands_ref))
                else:
                    if bands != bands_ref:
                        raise RuntimeError(f"{inputConfig=} {bands=} != {bands_ref=}")

            # Check that every dataset has the same set of patches
            if inputConfig.is_multipatch:
                if patches != patches_null:
                    raise RuntimeError(f"{inputConfig=} {patches=} != {patches_null=}")
            else:
                column_id = inputConfig.column_id
                if patches_ref is None:
                    bands = tuple(bands) if inputConfig.is_multiband else bands_sorted
                    for patch in patches:
                        data_patch = data_name[patch]
                        # Make sure any one-time operations are done once
                        # rather than for every band
                        added = False
                        for band in bands:
                            if tab := data_patch.get(band):
                                if not added:
                                    # add a patch column to fill in later
                                    tab.add_column(np.full(len(tab), patch), name="patch", index=1)
                                    # The id column should be objectId
                                    tab.rename_column(column_id, "objectId")
                                    added = True
                                else:
                                    del tab[column_id]
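                    # Order the reference patches by their minimum object ID so
                    # every dataset is stacked in a consistent patch order.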
                    patches_objid = {objid: patch for patch, objid in patches.items()}
                    patches_ref = {patch: objid for objid, patch in sorted(patches_objid.items())}
                elif {patch: patches[patch] for patch in patches_ref.keys()} != patches_ref:
                    raise RuntimeError(f"{inputConfig=} {patches=} != {patches_ref=}")
                else:
                    for data_patch in data_name.values():
                        for tab in data_patch.values():
                            del tab[column_id]

            data[name] = data_name

        self.log.info("Concatenating %s per-patch astropy Tables", len(patches))

        tables_read = []
        check_columns = self.config.drop_duplicate_columns or self.config.validate_duplicate_columns
        n_bands = len(bands_sorted)
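
        # Second pass: for each input dataset, hstack its per-band tables
        # within each patch, then vstack the patches into a single table.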
        for name, data_name in data.items():
            config_input = self.config.inputs[name]
            tables = []
            bands_missing = False

            # If this dataset is per-patch, loop over the reference patches;
            # otherwise (multipatch), loop over the single "null" patch
            for patch in patches_ref if not config_input.is_multipatch else patches_null:
                data_name_patch = data_name[patch]
                # If this is multiband, use the null band, falling back to an
                # empty list if there's no corresponding dataset
                if config_input.is_multiband:
                    tables_patch = data_name_patch.get(band_null, [])
                else:
                    # Get the tables in sorted band order, skipping missing bands
                    tables_patch = [
                        _tab for band in bands_sorted if (_tab := data_name_patch.get(band)) is not None
                    ]
                    # Check if any bands are missing
                    if not bands_missing and (len(tables_patch) != n_bands):
                        bands_missing = True
                # Join only if there's something to join
                if tables_patch:
                    table_patch = apTab.hstack(tables_patch, join_type="exact")
                    tables.append(table_patch)
                # If there's nothing to join, presumably the upstream task failed.
                # Stacking should handle some tasks failing but not others,
                # but this shouldn't be relied upon.
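
            # Stack the per-patch tables; use an outer join if any band was
            # missing so that absent columns end up masked.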
            table_new = (
                tables[0]
                if (len(tables) == 1)
                else apTab.vstack(tables, join_type="outer" if bands_missing else "exact")
            )
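
            # Compare against previously consolidated tables: duplicated
            # columns may be validated for equality and/or dropped, per config.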
            if check_columns:
                columns_new = set(x for x in table_new.colnames if x != config_input.join_column)
                for name_previous in tables_read:
                    table_old = data[name_previous]
                    columns_common = columns_new.intersection(
                        x for x in table_old.colnames if x != self.config.inputs[name_previous].join_column
                    )
                    for column_common in columns_common:
                        if self.config.validate_duplicate_columns:
                            if not np.array_equal(
                                table_new[column_common],
                                table_old[column_common],
                                equal_nan=True,
                            ):
                                raise RuntimeError(
                                    f"Joined table column={column_common} differs between {name} and"
                                    f" {name_previous} tables"
                                )
                        if self.config.drop_duplicate_columns:
                            del table_new[column_common]

            data[name] = table_new
            tables_read.append(name)

        # This will break if all tables have config.join_column
        # ... but that seems unlikely.
        table = apTab.hstack(
            [data[name] for name, config in self.config.inputs.items() if config.join_column is None],
            join_type=self.config.join_type,
        )
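        # Join each remaining dataset on its configured key column.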
        for name, config in self.config.inputs.items():
            if config.join_column:
                table = apTab.join(
                    table,
                    data[name],
                    join_type=self.config.join_type,
                    keys=config.join_column,
                )

        butlerQC.put(pipeBase.Struct(cat_output=table), outputRefs)