Coverage for python/lsst/daf/butler/formatters/parquet.py: 13%
402 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-23 02:06 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-23 02:06 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = (
25 "ParquetFormatter",
26 "arrow_to_pandas",
27 "arrow_to_astropy",
28 "arrow_to_numpy",
29 "arrow_to_numpy_dict",
30 "pandas_to_arrow",
31 "pandas_to_astropy",
32 "astropy_to_arrow",
33 "numpy_to_arrow",
34 "numpy_to_astropy",
35 "numpy_dict_to_arrow",
36 "arrow_schema_to_pandas_index",
37 "DataFrameSchema",
38 "ArrowAstropySchema",
39 "ArrowNumpySchema",
40)
42import collections.abc
43import itertools
44import json
45import re
46from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast
48import pyarrow as pa
49import pyarrow.parquet as pq
50from lsst.daf.butler import Formatter
51from lsst.utils.introspection import get_full_type_name
52from lsst.utils.iteration import ensure_iterable
54if TYPE_CHECKING: 54 ↛ 55line 54 didn't jump to line 55, because the condition on line 54 was never true
55 import astropy.table as atable
56 import numpy as np
57 import pandas as pd
class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: Optional[str] = None) -> Any:
        # Docstring inherited from Formatter.read.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)

        # ``schema.metadata`` is `None` when the file carries no key/value
        # metadata; normalize once so that every membership test is safe.
        metadata = schema.metadata if schema.metadata is not None else {}

        if component in ("columns", "schema"):
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            # Reading a single column is enough to count the rows.
            temp_table = pq.read_table(
                path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # NOTE: ``pop`` removes "columns" from the descriptor's parameter
            # mapping; whatever remains afterwards is treated as unsupported.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(schema, par_columns)

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Convert any supported in-memory table type to an arrow table and
        # write it to the (updated) location of the file descriptor.
        import numpy as np
        from astropy.table import Table as astropyTable

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame; pandas is optional so
                # only import it when the object looks like one.
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        pq.write_table(arrow_table, location.path)
def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Build a pandas DataFrame from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert. Any ``pandas`` metadata stored in the
        schema is used when constructing the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    # Single-threaded conversion; integer columns containing nulls are
    # returned as object columns.
    dataframe = arrow_table.to_pandas(use_threads=False, integer_object_nulls=True)
    return dataframe
def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Build an `astropy.Table` from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert. Astropy metadata stored in the schema
        (if any) is applied to the resulting table.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    table = Table(arrow_to_numpy_dict(arrow_table))

    schema_metadata = arrow_table.schema.metadata
    if schema_metadata is None:
        schema_metadata = {}
    _apply_astropy_metadata(table, schema_metadata)

    return table
def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Build a structured numpy array from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names match the column
        names of the input arrow table.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    # Columns with more than one dimension become sub-array fields whose
    # shape is the per-row shape of the column.
    dtype = [
        (name, col.dtype) if len(col.shape) <= 1 else (name, (col.dtype, col.shape[1:]))
        for name, col in columns.items()
    ]

    return np.rec.fromarrays(columns.values(), dtype=dtype)
def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's array. Columns that
        contain nulls are returned as masked arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = {} if schema.metadata is None else schema.metadata
    string_like_types = (pa.string(), pa.binary())

    numpy_dict = {}

    for name in schema.names:
        arrow_type = schema.field(name).type
        is_string_like = arrow_type in string_like_types
        arrow_column = arrow_table[name]

        if arrow_column.null_count != 0:
            # Nulls must be replaced by a value of the right type before
            # arrow will convert the column; the null positions are then
            # recorded in the mask of a masked array.
            fill_value = "" if is_string_like else arrow_type.to_pandas_dtype()(0)
            col = np.ma.masked_array(
                data=arrow_column.fill_null(fill_value).to_numpy(),
                mask=arrow_column.is_null().to_numpy(),
            )
        else:
            # No nulls: plain conversion.
            col = arrow_column.to_numpy()

        if is_string_like:
            col = col.astype(_arrow_string_to_numpy_dtype(schema, name, col))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(col) > 0:
                col = np.stack(col)
            else:
                # this is an empty column, and needs to be coerced to type.
                col = col.astype(arrow_type.value_type.to_pandas_dtype())

            # Restore the per-row shape recorded in the metadata (or the
            # flat list size when none was recorded).
            row_shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            col = col.reshape((len(arrow_table), *row_shape))

        numpy_dict[name] = col

    return numpy_dict
def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Convert a numpy array table to an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.

    Raises
    ------
    TypeError
        Raised if ``np_array`` is not a structured array, i.e. its dtype
        has no named fields.
    """
    field_names = np_array.dtype.names
    if field_names is None:
        # Without this check the loop over the field names fails with an
        # opaque "'NoneType' object is not iterable" TypeError.
        raise TypeError(
            "Input numpy array must be a structured array with named fields; "
            f"got an array with dtype {np_array.dtype}."
        )

    type_list = _numpy_dtype_to_arrow_types(np_array.dtype)

    md = {}
    md[b"lsst::arrow::rowcount"] = str(len(np_array))

    # Record string lengths and multi-dimensional shapes so that they can
    # be restored when converting back from arrow.
    for name in field_names:
        _append_numpy_string_metadata(md, name, np_array.dtype[name])
        _append_numpy_multidim_metadata(md, name, np_array.dtype[name])

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        np_array.dtype,
        len(np_array),
        np_array,
        schema,
    )

    arrow_table = pa.Table.from_arrays(arrays, schema=schema)

    return arrow_table
def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's array.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.

    Raises
    ------
    ValueError if columns in numpy_dict have unequal numbers of rows.
    """
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)

    metadata = {b"lsst::arrow::rowcount": str(rowcount)}
    # ``dtype.names`` is `None` for a dtype without fields.
    for name in dtype.names or ():
        field_dtype = dtype[name]
        _append_numpy_string_metadata(metadata, name, field_dtype)
        _append_numpy_multidim_metadata(metadata, name, field_dtype)

    schema = pa.schema(_numpy_dtype_to_arrow_types(dtype), metadata=metadata)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, rowcount, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Astropy table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    rowcount = len(astropy_table)

    metadata = {b"lsst::arrow::rowcount": str(rowcount)}
    for name in table_dtype.names:
        field_dtype = table_dtype[name]
        _append_numpy_string_metadata(metadata, name, field_dtype)
        _append_numpy_multidim_metadata(metadata, name, field_dtype)

    # Serialize the astropy table/column metadata as YAML so that it can be
    # re-applied when the table is read back.
    metadata[b"table_meta_yaml"] = "\n".join(meta.get_yaml_from_table(astropy_table))

    schema = pa.schema(_numpy_dtype_to_arrow_types(table_dtype), metadata=metadata)
    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, rowcount, astropy_table, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        Default string length used when the length cannot be inferred from
        the column (the column is empty or contains only nulls).

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata
    if md is None:
        md = {}

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if name.startswith("__"):
            continue
        column = arrow_table[name]
        if column.type == pa.string():
            # ``default`` covers both an empty column and a column holding
            # only nulls; a bare ``max`` would raise ValueError on the latter.
            strlen = max(
                (len(row.as_py()) for row in column if row.is_valid),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table
def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Build an astropy table from a pandas dataframe, keeping the index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe has multi-index columns.
    """
    import pandas as pd
    from astropy.table import Table

    has_multi_index_columns = isinstance(dataframe.columns, pd.MultiIndex)
    if has_multi_index_columns:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table
def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Wrap a numpy table in an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Numpy array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table; constructed with ``copy=False``.
    """
    from astropy.table import Table

    astropy_table = Table(data=np_array, copy=False)
    return astropy_table
def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only when
        the ``pandas`` schema metadata declares more than one column index
        level; column names starting with ``__`` are skipped otherwise.
    """
    import pandas as pd

    # A schema without key/value metadata reports ``metadata`` as `None`.
    metadata = schema.metadata if schema.metadata is not None else {}

    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]
    else:
        indexes = []

    if len(indexes) <= 1:
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])
def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Return the column names of an arrow schema as a list of strings.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        New list holding every name in ``schema.names``, in order.
    """
    return list(schema.names)
class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row selection of the input dataframe.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean mask selects no rows but keeps the columns.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty row list yields a zero-row table with the given schema.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        # The zero-row dataframe that represents the schema.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)
class ArrowAstropySchema:
    """Schema wrapper backed by a zero-row astropy table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table whose columns define the schema.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        # Keep only the column definitions, none of the rows.
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Create an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        empty_data = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        empty_table = Table(data=empty_data)

        schema_metadata = schema.metadata
        if schema_metadata is None:
            schema_metadata = {}
        _apply_astropy_metadata(empty_table, schema_metadata)

        return cls(empty_table)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The converted pyarrow schema.
        """
        arrow_table = astropy_to_arrow(self._schema)
        return arrow_table.schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The converted dataframe schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return DataFrameSchema.from_arrow(arrow_schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Return the equivalent `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            The converted arrow numpy schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return ArrowNumpySchema.from_arrow(arrow_schema)

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # Equal dtypes imply that both tables have the same column names.
        if self._schema.dtype != other._schema.dtype:
            return False

        # Per-column attributes that must also agree, checked in this order.
        compared_attributes = ("unit", "description", "format")
        for name in self._schema.columns:
            mine = self._schema[name]
            theirs = other._schema[name]
            for attribute in compared_attributes:
                if not getattr(mine, attribute) == getattr(theirs, attribute):
                    return False

        return True
class ArrowNumpySchema:
    """Schema wrapper backed by a numpy dtype.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype describing the table columns.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Create an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            The converted arrow numpy schema.
        """
        import numpy as np

        return cls(np.dtype(_schema_to_dtype_list(schema)))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Return the equivalent `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The converted arrow astropy schema.
        """
        import numpy as np

        empty_array = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty_array).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The converted dataframe schema.
        """
        import numpy as np

        empty_array = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty_array).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The converted pyarrow schema.
        """
        import numpy as np

        # A zero-row array with this dtype carries the full schema.
        empty_array = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty_array).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        return bool(self._dtype == other._dtype)
795def _split_multi_index_column_names(n: int, names: Iterable[str]) -> List[Sequence[str]]:
796 """Split a string that represents a multi-index column.
798 PyArrow maps Pandas' multi-index column names (which are tuples in Python)
799 to flat strings on disk. This routine exists to reconstruct the original
800 tuple.
802 Parameters
803 ----------
804 n : `int`
805 Number of levels in the `pandas.MultiIndex` that is being
806 reconstructed.
807 names : `~collections.abc.Iterable` [`str`]
808 Strings to be split.
810 Returns
811 -------
812 column_names : `list` [`tuple` [`str`]]
813 A list of multi-index column name tuples.
814 """
815 column_names: List[Sequence[str]] = []
817 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n)))
818 for name in names:
819 m = re.search(pattern, name)
820 if m is not None:
821 column_names.append(m.groups())
823 return column_names
def _standardize_multi_index_columns(
    schema: pa.Schema, columns: Union[List[tuple], dict[str, Union[str, List[str]]]]
) -> List[str]:
    """Turn a multi-index column selection into the stringified column
    names that PyArrow understands.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Pyarrow schema.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Columns to standardize; either explicit column tuples, or a mapping
        from index level name to the value(s) wanted at that level.

    Returns
    -------
    names : `list` [`str`]
        Stringified representation of a multi-index column name.

    Raises
    ------
    ValueError
        Raised if ``columns`` has an unsupported type, uses unknown level
        names, or requests a value not present in an index level.
    """
    pd_index = arrow_schema_to_pandas_index(schema)
    index_level_names = tuple(pd_index.names)

    # Explicit list of column tuples: stringify each one.
    if isinstance(columns, list):
        tuple_names = []
        for requested in columns:
            if not isinstance(requested, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(requested)}."
                )
            tuple_names.append(str(requested))
        return tuple_names

    if not isinstance(columns, collections.abc.Mapping):
        raise ValueError(
            "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
            f"Instead got a {get_full_type_name(columns)}."
        )
    if not set(index_level_names).issuperset(columns.keys()):
        raise ValueError(
            f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
        )

    # Levels missing from the mapping select every value of that level.
    factors = [
        ensure_iterable(columns.get(level_name, pd_index.levels[position]))
        for position, level_name in enumerate(index_level_names)
    ]

    product_names = []
    for requested in itertools.product(*factors):
        for position, value in enumerate(requested):
            if value not in pd_index.levels[position]:
                raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[position]!r}.")
        product_names.append(str(requested))

    return product_names
def _apply_astropy_metadata(astropy_table: atable.Table, metadata: dict) -> None:
    """Restore astropy metadata stored in the schema metadata, if any.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to update in place.
    metadata : `dict` [`bytes`]
        Schema metadata dict; only the ``table_meta_yaml`` entry is used.
    """
    from astropy.table import meta

    serialized = metadata.get(b"table_meta_yaml", None)
    if not serialized:
        # Nothing was serialized with the table.
        return

    yaml_lines = serialized.decode("UTF8").split("\n")
    meta_hdr = meta.get_header_from_yaml(yaml_lines)

    # Restore description, format, unit and meta for each column from the
    # header entry with the same column name.
    header_cols = {entry["name"]: entry for entry in meta_hdr["datatype"]}
    for col in astropy_table.columns.values():
        col_header = header_cols[col.name]
        for attr in ("description", "format", "unit", "meta"):
            if attr in col_header:
                setattr(col, attr, col_header[attr])

    if "meta" in meta_hdr:
        astropy_table.meta.update(meta_hdr["meta"])
def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Return the numpy dtype string for an arrow string/binary column.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column values used to measure the string length when the schema
        metadata does not record it.
    default_length : `int`, optional
        Length used when neither the metadata nor the column provides one.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string: ``U<len>`` for arrow string columns and
        ``|S<len>`` otherwise.
    """
    metadata = {} if schema.metadata is None else schema.metadata
    length_key = f"lsst::arrow::len::{name}".encode("UTF-8")

    if length_key in metadata:
        # String/bytes length from header.
        strlen = int(metadata[length_key])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(len(row) for row in numpy_column)
    else:
        strlen = default_length

    if schema.field(name).type == pa.string():
        return f"U{strlen}"
    return f"|S{strlen}"
947def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
948 """Append numpy string length keys to arrow metadata.
950 All column types are handled, but the metadata is only modified for
951 string and byte columns.
953 Parameters
954 ----------
955 metadata : `dict` [`bytes`, `str`]
956 Metadata dictionary; modified in place.
957 name : `str`
958 Column name.
959 dtype : `np.dtype`
960 Numpy dtype.
961 """
962 import numpy as np
964 if dtype.type is np.str_:
965 metadata[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(dtype.itemsize // 4)
966 elif dtype.type is np.bytes_:
967 metadata[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(dtype.itemsize)
970def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
971 """Append numpy multi-dimensional shapes to arrow metadata.
973 All column types are handled, but the metadata is only modified for
974 multi-dimensional columns.
976 Parameters
977 ----------
978 metadata : `dict` [`bytes`, `str`]
979 Metadata dictionary; modified in place.
980 name : `str`
981 Column name.
982 dtype : `np.dtype`
983 Numpy dtype.
984 """
985 if len(dtype.shape) > 1:
986 metadata[f"lsst::arrow::shape::{name}".encode("UTF-8")] = str(dtype.shape)
989def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
990 """Retrieve the shape from the metadata, if available.
992 Parameters
993 ----------
994 metadata : `dict` [`bytes`, `bytes`]
995 Metadata dictionary.
996 list_size : `int`
997 Size of the list datatype.
998 name : `str`
999 Column name.
1001 Returns
1002 -------
1003 shape : `tuple` [`int`]
1004 Shape associated with the column.
1006 Raises
1007 ------
1008 RuntimeError
1009 Raised if metadata is found but has incorrect format.
1010 """
1011 md_name = f"lsst::arrow::shape::{name}"
1012 if (encoded := md_name.encode("UTF-8")) in metadata:
1013 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8"))
1014 if groups is None:
1015 raise RuntimeError("Illegal value found in metadata.")
1016 shape = tuple(int(x) for x in groups[1].split(",") if x != "")
1017 else:
1018 shape = (list_size,)
1020 return shape
def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Describe a pyarrow schema as a numpy dtype specification.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        A list with name, type pairs.
    """
    metadata = {} if schema.metadata is None else schema.metadata
    string_like_types = (pa.string(), pa.binary())

    dtype: list[Any] = []
    for name in schema.names:
        arrow_type = schema.field(name).type
        field_type: Any
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists map to sub-array fields.
            row_shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            field_type = (arrow_type.value_type.to_pandas_dtype(), row_shape)
        elif arrow_type in string_like_types:
            field_type = _arrow_string_to_numpy_dtype(schema, name)
        else:
            field_type = arrow_type.to_pandas_dtype()
        dtype.append((name, field_type))

    return dtype
1052def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
1053 """Convert a numpy dtype to a list of arrow types.
1055 Parameters
1056 ----------
1057 dtype : `numpy.dtype`
1058 Numpy dtype to convert.
1060 Returns
1061 -------
1062 type_list : `list` [`object`]
1063 Converted list of arrow types.
1064 """
1065 from math import prod
1067 import numpy as np
1069 type_list: list[Any] = []
1070 if dtype.names is None:
1071 return type_list
1073 for name in dtype.names:
1074 dt = dtype[name]
1075 arrow_type: Any
1076 if len(dt.shape) > 0:
1077 arrow_type = pa.list_(
1078 pa.from_numpy_dtype(cast(tuple[np.dtype, tuple[int, ...]], dt.subdtype)[0].type),
1079 prod(dt.shape),
1080 )
1081 else:
1082 arrow_type = pa.from_numpy_dtype(dt.type)
1083 type_list.append((name, arrow_type))
1085 return type_list
1088def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
1089 """Extract equivalent table dtype from dict of numpy arrays.
1091 Parameters
1092 ----------
1093 numpy_dict : `dict` [`str`, `numpy.ndarray`]
1094 Dict with keys as the column names, values as the arrays.
1096 Returns
1097 -------
1098 dtype : `numpy.dtype`
1099 dtype of equivalent table.
1100 rowcount : `int`
1101 Number of rows in the table.
1103 Raises
1104 ------
1105 ValueError if columns in numpy_dict have unequal numbers of rows.
1106 """
1107 import numpy as np
1109 dtype_list = []
1110 rowcount = 0
1111 for name, col in numpy_dict.items():
1112 if rowcount == 0:
1113 rowcount = len(col)
1114 if len(col) != rowcount:
1115 raise ValueError(f"Column {name} has a different number of rows.")
1116 if len(col.shape) == 1:
1117 dtype_list.append((name, col.dtype))
1118 else:
1119 dtype_list.append((name, (col.dtype, col.shape[1:])))
1120 dtype = np.dtype(dtype_list)
1122 return (dtype, rowcount)
1125def _numpy_style_arrays_to_arrow_arrays(
1126 dtype: np.dtype,
1127 rowcount: int,
1128 np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
1129 schema: pa.Schema,
1130) -> list[pa.Array]:
1131 """Convert numpy-style arrays to arrow arrays.
1133 Parameters
1134 ----------
1135 dtype : `numpy.dtype`
1136 Numpy dtype of input table/arrays.
1137 rowcount : `int`
1138 Number of rows in input table/arrays.
1139 np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
1140 or `astropy.table.Table`
1141 Arrays to convert to arrow.
1142 schema : `pyarrow.Schema`
1143 Schema of arrow table.
1145 Returns
1146 -------
1147 arrow_arrays : `list` [`pyarrow.Array`]
1148 List of converted pyarrow arrays.
1149 """
1150 import numpy as np
1152 arrow_arrays: list[pa.Array] = []
1153 if dtype.names is None:
1154 return arrow_arrays
1156 for name in dtype.names:
1157 dt = dtype[name]
1158 val: Any
1159 if len(dt.shape) > 0:
1160 if rowcount > 0:
1161 val = np.split(np_style_arrays[name].ravel(), rowcount)
1162 else:
1163 val = []
1164 else:
1165 val = np_style_arrays[name]
1167 try:
1168 arrow_arrays.append(pa.array(val, type=schema.field(name).type))
1169 except pa.ArrowNotImplementedError as err:
1170 # Check if val is big-endian.
1171 if (np.little_endian and val.dtype.byteorder == ">") or (
1172 not np.little_endian and val.dtype.byteorder == "="
1173 ):
1174 # We need to convert the array to little-endian.
1175 val2 = val.byteswap()
1176 val2.dtype = val2.dtype.newbyteorder("<")
1177 arrow_arrays.append(pa.array(val2, type=schema.field(name).type))
1178 else:
1179 # This failed for some other reason so raise the exception.
1180 raise err
1182 return arrow_arrays