Coverage for python/lsst/daf/butler/formatters/parquet.py: 14%
442 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-05 01:26 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-05 01:26 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
# Public API of this module.
# NOTE(review): ``arrow_schema_to_column_list`` is defined in this module but
# is not listed here -- confirm whether omitting it is intentional.
__all__ = (
    "ParquetFormatter",
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    "pandas_to_arrow",
    "pandas_to_astropy",
    "astropy_to_arrow",
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
    "compute_row_group_size",
)
43import collections.abc
44import itertools
45import json
46import re
47from collections.abc import Iterable, Sequence
48from typing import TYPE_CHECKING, Any, cast
50import pyarrow as pa
51import pyarrow.parquet as pq
52from lsst.daf.butler import Formatter
53from lsst.utils.introspection import get_full_type_name
54from lsst.utils.iteration import ensure_iterable
56if TYPE_CHECKING:
57 import astropy.table as atable
58 import numpy as np
59 import pandas as pd
# Target size in bytes (1_000_000_000, i.e. 1 GB) for a Parquet row group.
# NOTE(review): not referenced in the visible part of this file; presumably
# consumed by ``compute_row_group_size`` -- confirm.
TARGET_ROW_GROUP_BYTES = 1_000_000_000
class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.

    ``read`` returns a `pyarrow.Table`, or for the ``columns``/``schema``
    components the file's `pyarrow.Schema`, or for the ``rowcount``
    component the number of rows.  ``write`` accepts a `pyarrow.Table`,
    an `astropy.table.Table`, a structured `numpy.ndarray`, a `dict` of
    `numpy.ndarray`, or a `pandas.DataFrame`, converting it to arrow first.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)
        # Parquet files not written by this formatter may carry no key/value
        # metadata, in which case ``schema.metadata`` is None.  Normalize once
        # so every membership test below is safe.
        metadata = schema.metadata if schema.metadata is not None else {}

        if component in ("columns", "schema"):
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible; otherwise use the
            # row count recorded in the Parquet footer, which needs no column
            # data and also works for files without any columns.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])
            return pq.read_metadata(path).num_rows

        par_columns = None
        if self.fileDescriptor.parameters:
            # Consume the "columns" parameter; anything left over afterwards
            # is unsupported and rejected below.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                # A DataFrame with multi-level column labels records more than
                # one entry in its pandas "column_indexes" metadata.
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    # Names starting with "__" (such as serialized index
                    # columns) are not selectable; a set gives O(1) lookups.
                    file_columns = {name for name in schema.names if not name.startswith("__")}

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Docstring inherited from Formatter.write.
        import numpy as np
        from astropy.table import Table as astropyTable

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame; pandas is imported lazily
                # because it is an optional dependency.
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)
def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Build a `pandas.DataFrame` from a `pyarrow.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.  Any ``pandas`` metadata stored in its schema is
        used when constructing the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The resulting dataframe.
    """
    # Single-threaded conversion; integer columns with nulls become object
    # columns holding ``None`` rather than being cast to float.
    conversion_options = {"use_threads": False, "integer_object_nulls": True}
    return arrow_table.to_pandas(**conversion_options)
def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Build an `astropy.table.Table` from a `pyarrow.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.  Astropy column/table metadata serialized in the
        schema metadata (if any) is restored on the result.

    Returns
    -------
    table : `astropy.table.Table`
        The resulting astropy table.
    """
    from astropy.table import Table

    astropy_table = Table(arrow_to_numpy_dict(arrow_table))

    schema_metadata = arrow_table.schema.metadata
    if schema_metadata is None:
        schema_metadata = {}
    _apply_astropy_metadata(astropy_table, schema_metadata)

    return astropy_table
def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Build a structured numpy record array from a `pyarrow.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with one row per table row and one field per column,
        named after the table's columns.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    fields = []
    for column_name, column in columns.items():
        if column.ndim > 1:
            # Multi-dimensional columns become sub-array fields whose shape
            # is everything after the row axis.
            fields.append((column_name, (column.dtype, column.shape[1:])))
        else:
            fields.append((column_name, column.dtype))

    return np.rec.fromarrays(list(columns.values()), dtype=fields)
def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Convert a pyarrow table to a dict of numpy arrays.

    Columns containing nulls are returned as `numpy.ma.MaskedArray` with the
    nulls masked.  String/binary columns are cast to fixed-width numpy
    string dtypes, and fixed-size-list columns are stacked into
    multi-dimensional arrays (reshaped using any shape recorded in the
    schema metadata).

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = schema.metadata if schema.metadata is not None else {}

    numpy_dict = {}

    for name in schema.names:
        t = schema.field(name).type

        if arrow_table[name].null_count == 0:
            # Regular non-masked column
            col = arrow_table[name].to_numpy()
        else:
            # For a masked column, we need to ask arrow to fill the null
            # values with an appropriately typed value before conversion.
            # Then we apply the mask to get a masked array of the correct type.

            if t in (pa.string(), pa.binary()):
                dummy = ""
            else:
                # A zero of the column's numpy scalar type.
                dummy = t.to_pandas_dtype()(0)

            col = np.ma.masked_array(
                data=arrow_table[name].fill_null(dummy).to_numpy(),
                mask=arrow_table[name].is_null().to_numpy(),
            )

        if t in (pa.string(), pa.binary()):
            # Fixed-width numpy string dtype; length comes from metadata or
            # the longest value in the column.
            col = col.astype(_arrow_string_to_numpy_dtype(schema, name, col))
        elif isinstance(t, pa.FixedSizeListType):
            if len(col) > 0:
                # Assumes to_numpy() on a fixed-size-list column yields a 1-D
                # object array of per-row arrays -- TODO confirm; np.stack
                # joins them into a (rows, list_size) array.
                col = np.stack(col)
            else:
                # this is an empty column, and needs to be coerced to type.
                col = col.astype(t.value_type.to_pandas_dtype())

            # Restore the original per-row shape (defaults to (list_size,)).
            shape = _multidim_shape_from_metadata(metadata, t.list_size, name)
            col = col.reshape((len(arrow_table), *shape))

        numpy_dict[name] = col

    return numpy_dict
def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Build a structured numpy array from a dict of column arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping of column name to column array.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured array with N rows and one field per dict key.
    """
    intermediate = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(intermediate)
def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Split a structured numpy array into a dict of column arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with named fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping of field name to column array.
    """
    intermediate = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(intermediate)
def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Convert a structured numpy array to an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple (named) fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.  The schema metadata records the row count
        plus string lengths and multi-dimensional shapes where applicable.

    Raises
    ------
    TypeError
        Raised if ``np_array`` is not a structured array (its dtype has no
        named fields).
    """
    if np_array.dtype.names is None:
        # Only structured arrays map onto table columns; fail with a clear
        # message instead of a generic iteration error further down.
        raise TypeError(
            "numpy_to_arrow requires a structured array with named fields; "
            f"got an array with dtype {np_array.dtype}."
        )

    type_list = _numpy_dtype_to_arrow_types(np_array.dtype)

    md = {}
    md[b"lsst::arrow::rowcount"] = str(len(np_array))

    for name in np_array.dtype.names:
        _append_numpy_string_metadata(md, name, np_array.dtype[name])
        _append_numpy_multidim_metadata(md, name, np_array.dtype[name])

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        np_array.dtype,
        len(np_array),
        np_array,
        schema,
    )

    arrow_table = pa.Table.from_arrays(arrays, schema=schema)

    return arrow_table
def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy column arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping of column name to column array.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The resulting arrow table, with the row count and any string
        length / multi-dimensional shape information in its metadata.

    Raises
    ------
    ValueError if columns in numpy_dict have unequal numbers of rows.
    """
    table_dtype, n_rows = _numpy_dict_to_dtype(numpy_dict)
    arrow_types = _numpy_dtype_to_arrow_types(table_dtype)

    schema_md = {b"lsst::arrow::rowcount": str(n_rows)}
    for column_name in table_dtype.names or ():
        column_dtype = table_dtype[column_name]
        _append_numpy_string_metadata(schema_md, column_name, column_dtype)
        _append_numpy_multidim_metadata(schema_md, column_name, column_dtype)

    schema = pa.schema(arrow_types, metadata=schema_md)
    columns = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, numpy_dict, schema)
    return pa.Table.from_arrays(columns, schema=schema)
def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Column units, descriptions, formats and table metadata are serialized to
    YAML and stored under the ``table_meta_yaml`` schema metadata key.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The resulting arrow table.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    n_rows = len(astropy_table)
    arrow_types = _numpy_dtype_to_arrow_types(table_dtype)

    schema_md = {b"lsst::arrow::rowcount": str(n_rows)}
    for column_name in table_dtype.names:
        column_dtype = table_dtype[column_name]
        _append_numpy_string_metadata(schema_md, column_name, column_dtype)
        _append_numpy_multidim_metadata(schema_md, column_name, column_dtype)

    yaml_lines = meta.get_yaml_from_table(astropy_table)
    schema_md[b"table_meta_yaml"] = "\n".join(yaml_lines)

    schema = pa.schema(arrow_types, metadata=schema_md)
    columns = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, astropy_table, schema)
    return pa.Table.from_arrays(columns, schema=schema)
def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Convert an astropy table to a dict of numpy arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(astropy_to_arrow(astropy_table))
def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column whose length cannot be
        inferred, i.e. a column that is empty or contains only nulls.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.  Its schema metadata gains the row count and,
        for every string column, the maximum string length.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata (guard against a schema without metadata).
    md = arrow_table.schema.metadata
    if md is None:
        md = {}

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.  Names starting with
    # "__" are skipped (e.g. serialized index columns).
    for name in arrow_table.column_names:
        if not name.startswith("__") and arrow_table[name].type == pa.string():
            # ``default`` covers both empty columns and columns whose values
            # are all null; a bare ``max`` would raise ValueError for the
            # latter.
            strlen = max(
                (len(row.as_py()) for row in arrow_table[name] if row.is_valid),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table
def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Build an astropy table from a pandas dataframe, keeping its index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert; must not have multi-index columns.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The resulting astropy table, with the index included as a column.

    Raises
    ------
    ValueError
        Raised if the dataframe columns are a `pandas.MultiIndex`.
    """
    import pandas as pd
    from astropy.table import Table

    if not isinstance(dataframe.columns, pd.MultiIndex):
        return Table.from_pandas(dataframe, index=True)
    raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")
def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Split a pandas dataframe into a dict of numpy column arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping of column name to column array.
    """
    intermediate = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(intermediate)
def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Wrap a structured numpy array in an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with named fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Table built from the array with ``copy=False``.
    """
    from astropy.table import Table

    # copy=False asks astropy not to copy the input data.
    return Table(np_array, copy=False)
def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    A `pandas.MultiIndex` is produced only when the schema's ``pandas``
    metadata records more than one column index level; otherwise a flat
    `pandas.Index` of the column names (excluding names that start with
    ``"__"``) is returned.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.  Its metadata may be None.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index.
    """
    import pandas as pd

    # Schemas without key/value metadata report ``metadata`` as None.
    metadata = schema.metadata if schema.metadata is not None else {}

    indexes = []
    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]

    if len(indexes) <= 1:
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])
def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """List the column names of an arrow schema.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Schema whose column names are wanted.

    Returns
    -------
    column_list : `list` [`str`]
        A new list holding every name in ``schema.names``, in order.
    """
    return list(schema.names)
class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is held as a zero-row selection of the input dataframe, so it
    keeps the dataframe's columns and column dtypes.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean selection drops every row but keeps columns.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty row list yields a zero-row table with the given schema.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        """The zero-row `pandas.DataFrame` representing this schema."""
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)
class ArrowAstropySchema:
    """Schema wrapper holding a zero-row `astropy.table.Table`.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table whose structure defines the schema.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        # Keep an empty slice of the table as the schema.
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Build an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Arrow schema to convert.  Any astropy metadata it carries is
            applied to the resulting table.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The equivalent astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        empty_rows = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        template = Table(data=empty_rows)

        schema_metadata = schema.metadata
        if schema_metadata is None:
            schema_metadata = {}
        _apply_astropy_metadata(template, schema_metadata)

        return cls(template)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The arrow schema.
        """
        return astropy_to_arrow(self._schema).schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The dataframe schema.
        """
        return DataFrameSchema.from_arrow(self.to_arrow_schema())

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Return the equivalent `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            The numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # Equal dtypes imply identical column names, so ``other`` can be
        # indexed by this table's column names below.
        if self._schema.dtype != other._schema.dtype:
            return False

        compared_attributes = ("unit", "description", "format")
        return all(
            getattr(self._schema[column], attribute) == getattr(other._schema[column], attribute)
            for column in self._schema.columns
            for attribute in compared_attributes
        )
class ArrowNumpySchema:
    """Schema wrapper holding the dtype of a structured numpy array.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Dtype that defines the schema.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Build an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Arrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            The equivalent numpy schema.
        """
        import numpy as np

        return cls(np.dtype(_schema_to_dtype_list(schema)))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Return the equivalent `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The dataframe schema.
        """
        return DataFrameSchema.from_arrow(self.to_arrow_schema())

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The arrow schema, derived from an empty array of this dtype.
        """
        import numpy as np

        empty_array = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty_array).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, ArrowNumpySchema):
            return bool(self._dtype == other._dtype)
        return NotImplemented
874def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]:
875 """Split a string that represents a multi-index column.
877 PyArrow maps Pandas' multi-index column names (which are tuples in Python)
878 to flat strings on disk. This routine exists to reconstruct the original
879 tuple.
881 Parameters
882 ----------
883 n : `int`
884 Number of levels in the `pandas.MultiIndex` that is being
885 reconstructed.
886 names : `~collections.abc.Iterable` [`str`]
887 Strings to be split.
889 Returns
890 -------
891 column_names : `list` [`tuple` [`str`]]
892 A list of multi-index column name tuples.
893 """
894 column_names: list[Sequence[str]] = []
896 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n)))
897 for name in names:
898 m = re.search(pattern, name)
899 if m is not None:
900 column_names.append(m.groups())
902 return column_names
905def _standardize_multi_index_columns(
906 pd_index: pd.MultiIndex,
907 columns: Any,
908 stringify: bool = True,
909) -> list[str | Sequence[Any]]:
910 """Transform a dictionary/iterable index from a multi-index column list
911 into a string directly understandable by PyArrow.
913 Parameters
914 ----------
915 pd_index : `pandas.MultiIndex`
916 Pandas multi-index.
917 columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
918 Columns to standardize.
919 stringify : `bool`, optional
920 Should the column names be stringified?
922 Returns
923 -------
924 names : `list` [`str`]
925 Stringified representation of a multi-index column name.
926 """
927 index_level_names = tuple(pd_index.names)
929 names: list[str | Sequence[Any]] = []
931 if isinstance(columns, list):
932 for requested in columns:
933 if not isinstance(requested, tuple):
934 raise ValueError(
935 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
936 f"Instead got a {get_full_type_name(requested)}."
937 )
938 if stringify:
939 names.append(str(requested))
940 else:
941 names.append(requested)
942 else:
943 if not isinstance(columns, collections.abc.Mapping):
944 raise ValueError(
945 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
946 f"Instead got a {get_full_type_name(columns)}."
947 )
948 if not set(index_level_names).issuperset(columns.keys()):
949 raise ValueError(
950 f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
951 )
952 factors = [
953 ensure_iterable(columns.get(level, pd_index.levels[i]))
954 for i, level in enumerate(index_level_names)
955 ]
956 for requested in itertools.product(*factors):
957 for i, value in enumerate(requested):
958 if value not in pd_index.levels[i]:
959 raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.")
960 if stringify:
961 names.append(str(requested))
962 else:
963 names.append(requested)
965 return names
def _apply_astropy_metadata(astropy_table: atable.Table, metadata: dict) -> None:
    """Apply any astropy metadata from the schema metadata.

    Reads the YAML header stored under ``b"table_meta_yaml"`` (if present)
    and copies each column's ``description``, ``format``, ``unit`` and
    ``meta`` onto the matching column of ``astropy_table``, then merges the
    table-level ``meta``.  Columns not described in the header are left
    untouched.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to apply metadata to; modified in place.
    metadata : `dict` [`bytes`]
        Metadata dict.
    """
    from astropy.table import meta

    meta_yaml = metadata.get(b"table_meta_yaml", None)
    if not meta_yaml:
        return

    meta_hdr = meta.get_header_from_yaml(meta_yaml.decode("UTF8").split("\n"))

    # Set description, format, unit, meta from the column
    # metadata that was serialized with the table.
    header_cols = {x["name"]: x for x in meta_hdr["datatype"]}
    for col in astropy_table.columns.values():
        col_header = header_cols.get(col.name)
        if col_header is None:
            # The table has a column the serialized header does not describe
            # (e.g. one added after the header was written); indexing the
            # header directly would raise KeyError here.
            continue
        for attr in ("description", "format", "unit", "meta"):
            if attr in col_header:
                setattr(col, attr, col_header[attr])

    if "meta" in meta_hdr:
        astropy_table.meta.update(meta_hdr["meta"])
def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Choose the fixed-width numpy dtype string for an arrow string/binary column.

    The length is taken from the ``lsst::arrow::len::<name>`` schema
    metadata if present, else from the longest entry of ``numpy_column``
    when one is given and non-empty, else ``default_length``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Schema containing the column.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column values used to infer the length.
    default_length : `int`, optional
        Fallback string length.

    Returns
    -------
    dtype_str : `str`
        ``"U<len>"`` for arrow string columns, ``"|S<len>"`` otherwise.
    """
    length_key = f"lsst::arrow::len::{name}".encode("UTF-8")
    metadata = schema.metadata if schema.metadata is not None else {}

    if length_key in metadata:
        strlen = int(metadata[length_key])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(len(entry) for entry in numpy_column)
    else:
        strlen = default_length

    prefix = "U" if schema.field(name).type == pa.string() else "|S"
    return f"{prefix}{strlen}"
1034def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
1035 """Append numpy string length keys to arrow metadata.
1037 All column types are handled, but the metadata is only modified for
1038 string and byte columns.
1040 Parameters
1041 ----------
1042 metadata : `dict` [`bytes`, `str`]
1043 Metadata dictionary; modified in place.
1044 name : `str`
1045 Column name.
1046 dtype : `np.dtype`
1047 Numpy dtype.
1048 """
1049 import numpy as np
1051 if dtype.type is np.str_:
1052 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize // 4)
1053 elif dtype.type is np.bytes_:
1054 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize)
1057def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
1058 """Append numpy multi-dimensional shapes to arrow metadata.
1060 All column types are handled, but the metadata is only modified for
1061 multi-dimensional columns.
1063 Parameters
1064 ----------
1065 metadata : `dict` [`bytes`, `str`]
1066 Metadata dictionary; modified in place.
1067 name : `str`
1068 Column name.
1069 dtype : `np.dtype`
1070 Numpy dtype.
1071 """
1072 if len(dtype.shape) > 1:
1073 metadata[f"lsst::arrow::shape::{name}".encode()] = str(dtype.shape)
1076def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
1077 """Retrieve the shape from the metadata, if available.
1079 Parameters
1080 ----------
1081 metadata : `dict` [`bytes`, `bytes`]
1082 Metadata dictionary.
1083 list_size : `int`
1084 Size of the list datatype.
1085 name : `str`
1086 Column name.
1088 Returns
1089 -------
1090 shape : `tuple` [`int`]
1091 Shape associated with the column.
1093 Raises
1094 ------
1095 RuntimeError
1096 Raised if metadata is found but has incorrect format.
1097 """
1098 md_name = f"lsst::arrow::shape::{name}"
1099 if (encoded := md_name.encode("UTF-8")) in metadata:
1100 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8"))
1101 if groups is None:
1102 raise RuntimeError("Illegal value found in metadata.")
1103 shape = tuple(int(x) for x in groups[1].split(",") if x != "")
1104 else:
1105 shape = (list_size,)
1107 return shape
def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Describe an arrow schema as a numpy structured-dtype field list.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Schema to describe.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        One ``(name, type)`` pair per schema field, suitable for
        `numpy.dtype`.
    """
    metadata = schema.metadata if schema.metadata is not None else {}

    def _numpy_type_for(field_name: str) -> Any:
        field_type = schema.field(field_name).type
        if isinstance(field_type, pa.FixedSizeListType):
            # Fixed-size lists become sub-array fields.
            shape = _multidim_shape_from_metadata(metadata, field_type.list_size, field_name)
            return (field_type.value_type.to_pandas_dtype(), shape)
        if field_type in (pa.string(), pa.binary()):
            return _arrow_string_to_numpy_dtype(schema, field_name)
        return field_type.to_pandas_dtype()

    return [(field_name, _numpy_type_for(field_name)) for field_name in schema.names]
def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Translate the fields of a structured numpy dtype into arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`tuple`]
        One ``(name, arrow_type)`` pair per field of ``dtype``, in field
        order. Subarray fields become fixed-size arrow lists whose length is
        the total number of elements in the subarray. Empty if ``dtype`` has
        no named fields.
    """
    from math import prod

    field_names = dtype.names
    if field_names is None:
        return []

    def _field_to_arrow(field_dtype: np.dtype) -> Any:
        if not field_dtype.shape:
            return pa.from_numpy_dtype(field_dtype.type)
        # Multidimensional subarrays are flattened into a single fixed-size
        # list of the base element type.
        return pa.list_(pa.from_numpy_dtype(field_dtype.base.type), prod(field_dtype.shape))

    return [(field_name, _field_to_arrow(dtype[field_name])) for field_name in field_names]
def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
    """Extract equivalent table dtype from dict of numpy arrays.

    One-dimensional arrays become scalar fields; arrays with more dimensions
    become subarray fields whose shape is the array shape without the
    leading (row) axis.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    dtype : `numpy.dtype`
        dtype of equivalent table.
    rowcount : `int`
        Number of rows in the table (0 for an empty dict).

    Raises
    ------
    ValueError
        Raised if columns in numpy_dict have unequal numbers of rows.
    """
    import numpy as np

    dtype_list = []
    # None (not 0) marks "not yet seen" so an empty first column still
    # fixes the expected row count for the remaining columns.
    rowcount: int | None = None
    for name, col in numpy_dict.items():
        if rowcount is None:
            rowcount = len(col)
        elif len(col) != rowcount:
            raise ValueError(f"Column {name} has a different number of rows.")
        if len(col.shape) == 1:
            dtype_list.append((name, col.dtype))
        else:
            dtype_list.append((name, (col.dtype, col.shape[1:])))
    dtype = np.dtype(dtype_list)

    return (dtype, 0 if rowcount is None else rowcount)
def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Multidimensional columns are split into one flattened sub-array per row
    before conversion. If pyarrow rejects a column with
    ``ArrowNotImplementedError`` and the column is big-endian, the column is
    byte-swapped to little-endian and converted again; this applies to both
    scalar and multidimensional columns.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
        or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays.

    Raises
    ------
    pyarrow.ArrowNotImplementedError
        Raised if a column cannot be converted for a reason other than
        big-endian byte order.
    """
    import numpy as np

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    def _is_big_endian(col_dtype: np.dtype) -> bool:
        # numpy reports native byte order as "=", so a native dtype is only
        # big-endian when the host itself is big-endian.
        if np.little_endian:
            return col_dtype.byteorder == ">"
        return col_dtype.byteorder == "="

    def _arrow_input(col: Any, is_multidim: bool) -> Any:
        # Multidimensional columns become a list of flattened per-row arrays.
        if not is_multidim:
            return col
        if rowcount > 0:
            return np.split(col.ravel(), rowcount)
        return []

    for name in dtype.names:
        is_multidim = len(dtype[name].shape) > 0
        arrow_type = schema.field(name).type
        col = np_style_arrays[name]
        try:
            arrow_arrays.append(pa.array(_arrow_input(col, is_multidim), type=arrow_type))
        except pa.ArrowNotImplementedError:
            if not _is_big_endian(col.dtype):
                # This failed for some other reason so re-raise.
                raise
            # Swap the source column itself (before any per-row splitting, so
            # multidimensional columns work too) and relabel it little-endian.
            swapped = col.byteswap().view(col.dtype.newbyteorder("<"))
            arrow_arrays.append(pa.array(_arrow_input(swapped, is_multidim), type=arrow_type))

    return arrow_arrays
def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row group
    that targets the persisted size on disk (or smaller). The exact size on
    disk depends on the compression settings and ratios; typical binary data
    tables will have around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Column widths are estimated as follows:

    - Fixed-width types use their bit width.
    - ``string``/``binary`` columns use the length stored under
      ``lsst::arrow::len::{name}`` in the schema metadata, or guess 10
      characters.
    - Fixed-size lists multiply the element width by the list size.
    - Variable-length lists assume 10 elements.
    - Null types take no space.
    - Any type without a fixed width (e.g. list elements that are strings)
      falls back to the 10-character string guess instead of raising.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size; always at
        least 1.
    """
    # Guessed string length (in characters) when none is recorded.
    default_strlen = 10
    metadata = schema.metadata if schema.metadata is not None else {}

    def _value_bit_width(value_type: pa.DataType) -> int:
        # Estimated width of one value. pyarrow raises ValueError from
        # ``bit_width`` for non-fixed-width types, so guess for those.
        if value_type == pa.null():
            return 0
        try:
            return value_type.bit_width
        except ValueError:
            return 8 * default_strlen

    bit_width = 0
    for name in schema.names:
        t = schema.field(name).type

        if t in (pa.string(), pa.binary()):
            encoded = f"lsst::arrow::len::{name}".encode("UTF-8")
            # String/bytes length from header if recorded, else a guess.
            strlen = int(metadata[encoded]) if encoded in metadata else default_strlen
            # Assuming UTF-8 encoding, and very few wide characters.
            t_width = 8 * strlen
        elif isinstance(t, pa.FixedSizeListType):
            t_width = t.list_size * _value_bit_width(t.value_type)
        elif isinstance(t, pa.ListType):
            # This is a variable length list, just choose something
            # arbitrary.
            t_width = 10 * _value_bit_width(t.value_type)
        else:
            # Scalars (including null, which contributes 0).
            t_width = _value_bit_width(t)

        bit_width += t_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    byte_width = max(bit_width, 8) // 8

    # Never return 0 rows, even if a single row exceeds the target size.
    return max(target_size // byte_width, 1)