Coverage for python/lsst/daf/butler/formatters/parquet.py: 12%
443 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-03 09:15 +0000
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-03 09:15 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = (
25 "ParquetFormatter",
26 "arrow_to_pandas",
27 "arrow_to_astropy",
28 "arrow_to_numpy",
29 "arrow_to_numpy_dict",
30 "pandas_to_arrow",
31 "pandas_to_astropy",
32 "astropy_to_arrow",
33 "numpy_to_arrow",
34 "numpy_to_astropy",
35 "numpy_dict_to_arrow",
36 "arrow_schema_to_pandas_index",
37 "DataFrameSchema",
38 "ArrowAstropySchema",
39 "ArrowNumpySchema",
40 "compute_row_group_size",
41)
43import collections.abc
44import itertools
45import json
46import re
47from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, cast
49import pyarrow as pa
50import pyarrow.parquet as pq
51from lsst.daf.butler import Formatter
52from lsst.utils.introspection import get_full_type_name
53from lsst.utils.iteration import ensure_iterable
55if TYPE_CHECKING:
56 import astropy.table as atable
57 import numpy as np
58 import pandas as pd
60TARGET_ROW_GROUP_BYTES = 1_000_000_000
class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.

    Reads return a `pyarrow.Table`; writes accept arrow, astropy, numpy
    (structured array or dict of arrays), and pandas containers and
    convert them to arrow before serialization.
    """

    # File extension used for datasets written by this formatter.
    extension = ".parq"

    def read(self, component: Optional[str] = None) -> Any:
        # Docstring inherited from Formatter.read.
        # Read only the schema first; full data is deferred until needed.
        schema = pq.read_schema(self.fileDescriptor.location.path)

        if component in ("columns", "schema"):
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in schema.metadata:
                return int(schema.metadata[b"lsst::arrow::rowcount"])

            # Fall back to reading a single column and counting its rows.
            temp_table = pq.read_table(
                self.fileDescriptor.location.path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # "columns" restricts the read to a subset of columns.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in schema.metadata:
                    md = json.loads(schema.metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    # Columns starting with "__" are internal (e.g. index
                    # columns) and are not user-selectable.
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    # Multi-index columns must be translated to the flat
                    # string form pyarrow stores on disk.
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                # Anything left after popping "columns" is unsupported.
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        metadata = schema.metadata if schema.metadata is not None else {}
        arrow_table = pq.read_table(
            self.fileDescriptor.location.path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Docstring inherited from Formatter.write.
        import numpy as np
        from astropy.table import Table as astropyTable

        # Normalize all supported in-memory types to an arrow table.
        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        # Choose a row group size appropriate for the schema's row width.
        row_group_size = compute_row_group_size(arrow_table.schema)

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)
def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Convert a pyarrow table to a pandas DataFrame.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert. Any ``pandas`` metadata present in the
        schema is honored when building the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        Converted pandas dataframe.
    """
    # Single-threaded conversion keeps results deterministic; nulls in
    # integer columns become objects rather than being cast to float.
    dataframe = arrow_table.to_pandas(use_threads=False, integer_object_nulls=True)
    return dataframe
def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Convert a pyarrow table to an `astropy.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert. Astropy unit metadata stored in the
        schema is applied to the resulting table's columns.

    Returns
    -------
    table : `astropy.Table`
        Converted astropy table.
    """
    from astropy.table import Table

    table = Table(arrow_to_numpy_dict(arrow_table))

    # Schema metadata may be absent; normalize to an empty mapping before
    # restoring units/descriptions/formats.
    md = arrow_table.schema.metadata
    _apply_astropy_metadata(table, md if md is not None else {})

    return table
def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Convert a pyarrow table to a structured numpy array.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Numpy array table with N rows and the same column names
        as the input arrow table.
    """
    import numpy as np

    numpy_dict = arrow_to_numpy_dict(arrow_table)

    dtype = []
    for name, col in numpy_dict.items():
        # Use the bound column directly instead of re-looking up
        # ``numpy_dict[name]`` (redundant dict access in the original).
        if len(shape := col.shape) <= 1:
            dtype.append((name, col.dtype))
        else:
            # Multi-dimensional column: record the per-row sub-shape.
            dtype.append((name, (col.dtype, shape[1:])))

    array = np.rec.fromarrays(numpy_dict.values(), dtype=dtype)

    return array
def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Convert a pyarrow table to a dict of numpy arrays.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    # Schema metadata may be absent; normalize to an empty mapping.
    metadata = schema.metadata if schema.metadata is not None else {}

    numpy_dict = {}

    for name in schema.names:
        t = schema.field(name).type

        if arrow_table[name].null_count == 0:
            # Regular non-masked column
            col = arrow_table[name].to_numpy()
        else:
            # For a masked column, we need to ask arrow to fill the null
            # values with an appropriately typed value before conversion.
            # Then we apply the mask to get a masked array of the correct type.
            if t in (pa.string(), pa.binary()):
                dummy = ""
            else:
                dummy = t.to_pandas_dtype()(0)

            col = np.ma.masked_array(
                data=arrow_table[name].fill_null(dummy).to_numpy(),
                mask=arrow_table[name].is_null().to_numpy(),
            )

        if t in (pa.string(), pa.binary()):
            # Convert object-dtype strings/bytes to a fixed-width numpy
            # dtype sized from schema metadata (or the column itself).
            col = col.astype(_arrow_string_to_numpy_dtype(schema, name, col))
        elif isinstance(t, pa.FixedSizeListType):
            # Fixed-size-list columns become multi-dimensional numpy columns.
            if len(col) > 0:
                col = np.stack(col)
            else:
                # this is an empty column, and needs to be coerced to type.
                col = col.astype(t.value_type.to_pandas_dtype())

            # Recover the original multi-dim shape from metadata if recorded.
            shape = _multidim_shape_from_metadata(metadata, t.list_size, name)
            col = col.reshape((len(arrow_table), *shape))

        numpy_dict[name] = col

    return numpy_dict
def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Convert a dict of numpy arrays to a structured numpy array.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Numpy array table with N rows and columns names from the dict keys.
    """
    # Round-trip through an arrow table, which handles dtype assembly
    # and row-count validation.
    arrow_table = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(arrow_table)
def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Convert a structured numpy array to a dict of numpy arrays.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    # Round-trip through an arrow table so string widths and multi-dim
    # shapes are normalized consistently with the other converters.
    arrow_table = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(arrow_table)
def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Convert a numpy array table to an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
    """
    dtype = np_array.dtype

    metadata = {b"lsst::arrow::rowcount": str(len(np_array))}
    for name in dtype.names:
        # Record string widths and multi-dimensional shapes so the
        # conversion back to numpy can reconstruct the original dtype.
        _append_numpy_string_metadata(metadata, name, dtype[name])
        _append_numpy_multidim_metadata(metadata, name, dtype[name])

    arrow_schema = pa.schema(_numpy_dtype_to_arrow_types(dtype), metadata=metadata)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        dtype,
        len(np_array),
        np_array,
        arrow_schema,
    )

    return pa.Table.from_arrays(arrays, schema=arrow_schema)
def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Convert a dict of numpy arrays to an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.

    Raises
    ------
    ValueError if columns in numpy_dict have unequal numbers of rows.
    """
    # Derive a structured dtype and validated row count from the dict.
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)

    metadata = {b"lsst::arrow::rowcount": str(rowcount)}
    if dtype.names is not None:
        for name in dtype.names:
            # Preserve string widths and multi-dim shapes for round-trips.
            _append_numpy_string_metadata(metadata, name, dtype[name])
            _append_numpy_multidim_metadata(metadata, name, dtype[name])

    arrow_schema = pa.schema(_numpy_dtype_to_arrow_types(dtype), metadata=metadata)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        dtype,
        rowcount,
        numpy_dict,
        arrow_schema,
    )

    return pa.Table.from_arrays(arrays, schema=arrow_schema)
def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Convert an astropy table to an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
    """
    from astropy.table import meta

    dtype = astropy_table.dtype

    metadata = {b"lsst::arrow::rowcount": str(len(astropy_table))}
    for name in dtype.names:
        # Preserve string widths and multi-dim shapes for round-trips.
        _append_numpy_string_metadata(metadata, name, dtype[name])
        _append_numpy_multidim_metadata(metadata, name, dtype[name])

    # Serialize astropy-specific attributes (units, descriptions, table
    # meta) as YAML so they survive the trip through parquet.
    metadata[b"table_meta_yaml"] = "\n".join(meta.get_yaml_from_table(astropy_table))

    arrow_schema = pa.schema(_numpy_dtype_to_arrow_types(dtype), metadata=metadata)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        dtype,
        len(astropy_table),
        astropy_table,
        arrow_schema,
    )

    return pa.Table.from_arrays(arrays, schema=arrow_schema)
def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Convert an astropy table to a dict of numpy arrays.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    # Round-trip through arrow to normalize string widths and
    # multi-dimensional column shapes.
    return arrow_to_numpy_dict(astropy_to_arrow(astropy_table))
def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        Default string length when not in metadata or cannot be inferred
        from column.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if not name.startswith("__"):
            if arrow_table[name].type == pa.string():
                # Record the maximum string length so round-trips to numpy
                # can size fixed-width string columns; null rows are skipped.
                # NOTE(review): a non-empty but all-null string column would
                # make max() raise on an empty sequence — confirm callers
                # cannot produce that case.
                if len(arrow_table[name]) > 0:
                    strlen = max(len(row.as_py()) for row in arrow_table[name] if row.is_valid)
                else:
                    strlen = default_length
                md[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(strlen)

    # Schema metadata is immutable on the table; attach the updated mapping.
    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table
def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Convert a pandas dataframe to an astropy table, preserving indexes.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe has multi-index columns.
    """
    import pandas as pd
    from astropy.table import Table

    # Multi-index columns have no astropy column-name equivalent.
    if isinstance(dataframe.columns, pd.MultiIndex):
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table
def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Convert a pandas dataframe to an dict of numpy arrays.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    # Round-trip through arrow so dtypes and string widths are handled
    # identically to the other converters.
    arrow_table = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(arrow_table)
def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Convert a numpy table to an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.
    """
    from astropy.table import Table

    # Wrap without copying: the astropy table shares the input buffer.
    astropy_table = Table(data=np_array, copy=False)
    return astropy_table
def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index.
    """
    import pandas as pd

    # Normalize: schema.metadata is None when no metadata was written.
    metadata = schema.metadata if schema.metadata is not None else {}

    if b"pandas" in metadata:
        md = json.loads(metadata[b"pandas"])
        indexes = md["column_indexes"]
        len_indexes = len(indexes)
    else:
        len_indexes = 0

    if len_indexes <= 1:
        # pd.Index requires a concrete collection (it rejects generators),
        # so build a list; internal "__" columns are excluded.
        return pd.Index([name for name in schema.names if not name.startswith("__")])
    else:
        # Reconstruct the multi-index tuples pyarrow flattened to strings.
        raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
        return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])
def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Convert an arrow schema to a list of string column names.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        Converted list of column names.
    """
    # list() copies the name sequence directly; a comprehension that just
    # re-yields each element is redundant (C416).
    return list(schema.names)
class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # A zero-row boolean selection preserves the column names/dtypes
        # and index structure without keeping any data.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # ``[] * n`` is always the empty list, so this builds a zero-row
        # table that carries only the schema.
        empty_table = pa.Table.from_pylist([] * len(schema.names), schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        # The stored "schema" is itself an empty DataFrame (the original
        # annotation of ``np.dtype`` was incorrect).
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        # DataFrame.equals compares both structure and (empty) contents.
        return self._schema.equals(other._schema)
class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        # A zero-row slice keeps column names, dtypes, units and other
        # column attributes without retaining any data.
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Convert an arrow schema into a ArrowAstropySchema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        dtype = _schema_to_dtype_list(schema)

        # Build an empty table with the right structure.
        data = np.zeros(0, dtype=dtype)
        astropy_table = Table(data=data)

        metadata = schema.metadata if schema.metadata is not None else {}

        # Restore units/descriptions/formats serialized in the metadata.
        _apply_astropy_metadata(astropy_table, metadata)

        return cls(astropy_table)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        return astropy_to_arrow(self._schema).schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a DataFrameSchema.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        return DataFrameSchema.from_arrow(astropy_to_arrow(self._schema).schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(astropy_to_arrow(self._schema).schema)

    @property
    def schema(self) -> atable.Table:
        # The stored schema is itself a zero-row astropy table.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # If this comparison passes then the two tables have the
        # same column names.
        if self._schema.dtype != other._schema.dtype:
            return False

        # dtype equality does not cover astropy column attributes;
        # compare those explicitly.
        for name in self._schema.columns:
            if not self._schema[name].unit == other._schema[name].unit:
                return False
            if not self._schema[name].description == other._schema[name].description:
                return False
            if not self._schema[name].format == other._schema[name].format:
                return False

        return True
class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype to convert.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        # The structured dtype fully describes the table layout.
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Convert an arrow schema into an `ArrowNumpySchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        return cls(np.dtype(_schema_to_dtype_list(schema)))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np

        # Convert via an empty array carrying the dtype.
        empty = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to a `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        # np.dtype equality returns a plain bool.
        return self._dtype == other._dtype
874def _split_multi_index_column_names(n: int, names: Iterable[str]) -> List[Sequence[str]]:
875 """Split a string that represents a multi-index column.
877 PyArrow maps Pandas' multi-index column names (which are tuples in Python)
878 to flat strings on disk. This routine exists to reconstruct the original
879 tuple.
881 Parameters
882 ----------
883 n : `int`
884 Number of levels in the `pandas.MultiIndex` that is being
885 reconstructed.
886 names : `~collections.abc.Iterable` [`str`]
887 Strings to be split.
889 Returns
890 -------
891 column_names : `list` [`tuple` [`str`]]
892 A list of multi-index column name tuples.
893 """
894 column_names: List[Sequence[str]] = []
896 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n)))
897 for name in names:
898 m = re.search(pattern, name)
899 if m is not None:
900 column_names.append(m.groups())
902 return column_names
def _standardize_multi_index_columns(
    pd_index: pd.MultiIndex,
    columns: Any,
    stringify: bool = True,
) -> list[str | Sequence[Any]]:
    """Transform a dictionary/iterable index from a multi-index column list
    into a string directly understandable by PyArrow.

    Parameters
    ----------
    pd_index : `pandas.MultiIndex`
        Pandas multi-index.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Columns to standardize.
    stringify : `bool`, optional
        Should the column names be stringified?

    Returns
    -------
    names : `list` [`str`]
        Stringified representation of a multi-index column name.
    """
    index_level_names = tuple(pd_index.names)

    names: list[str | Sequence[Any]] = []

    if isinstance(columns, list):
        # List form: each entry must already be a full multi-index tuple.
        for requested in columns:
            if not isinstance(requested, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(requested)}."
                )
            if stringify:
                # str(tuple) matches the flat on-disk column name format.
                names.append(str(requested))
            else:
                names.append(requested)
    else:
        # Dict form: keys select index levels, values select level values;
        # unspecified levels default to all of that level's values.
        if not isinstance(columns, collections.abc.Mapping):
            raise ValueError(
                "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                f"Instead got a {get_full_type_name(columns)}."
            )
        if not set(index_level_names).issuperset(columns.keys()):
            raise ValueError(
                f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
            )
        factors = [
            ensure_iterable(columns.get(level, pd_index.levels[i]))
            for i, level in enumerate(index_level_names)
        ]
        # The cartesian product enumerates every requested column tuple.
        for requested in itertools.product(*factors):
            for i, value in enumerate(requested):
                if value not in pd_index.levels[i]:
                    raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.")
            if stringify:
                names.append(str(requested))
            else:
                names.append(requested)

    return names
def _apply_astropy_metadata(astropy_table: atable.Table, metadata: dict) -> None:
    """Apply any astropy metadata from the schema metadata.

    Modifies ``astropy_table`` in place; does nothing when no
    ``table_meta_yaml`` entry is present in ``metadata``.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to apply metadata.
    metadata : `dict` [`bytes`]
        Metadata dict.
    """
    from astropy.table import meta

    meta_yaml = metadata.get(b"table_meta_yaml", None)
    if meta_yaml:
        # The YAML was stored as one newline-joined bytes value.
        meta_yaml = meta_yaml.decode("UTF8").split("\n")
        meta_hdr = meta.get_header_from_yaml(meta_yaml)

        # Set description, format, unit, meta from the column
        # metadata that was serialized with the table.
        header_cols = {x["name"]: x for x in meta_hdr["datatype"]}
        for col in astropy_table.columns.values():
            for attr in ("description", "format", "unit", "meta"):
                if attr in header_cols[col.name]:
                    setattr(col, attr, header_cols[col.name][attr])

        # Table-level metadata, if present, is merged into table.meta.
        if "meta" in meta_hdr:
            astropy_table.meta.update(meta_hdr["meta"])
def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Get the numpy dtype string associated with an arrow column.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column to determine numpy string dtype.
    default_length : `int`, optional
        Default string length when not in metadata or cannot be inferred
        from column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string.
    """
    # Special-case for string and binary columns
    md_name = f"lsst::arrow::len::{name}"
    strlen = default_length
    metadata = schema.metadata if schema.metadata is not None else {}
    if (encoded := md_name.encode("UTF-8")) in metadata:
        # String/bytes length from header. Index the normalized mapping
        # rather than ``schema.metadata`` (the original inconsistently
        # re-read the raw attribute here).
        strlen = int(metadata[encoded])
    elif numpy_column is not None and len(numpy_column) > 0:
        # No metadata: infer the width from the longest value present.
        strlen = max(len(row) for row in numpy_column)

    # Unicode for arrow strings, fixed-width bytes for binary.
    dtype = f"U{strlen}" if schema.field(name).type == pa.string() else f"|S{strlen}"

    return dtype
1035def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
1036 """Append numpy string length keys to arrow metadata.
1038 All column types are handled, but the metadata is only modified for
1039 string and byte columns.
1041 Parameters
1042 ----------
1043 metadata : `dict` [`bytes`, `str`]
1044 Metadata dictionary; modified in place.
1045 name : `str`
1046 Column name.
1047 dtype : `np.dtype`
1048 Numpy dtype.
1049 """
1050 import numpy as np
1052 if dtype.type is np.str_:
1053 metadata[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(dtype.itemsize // 4)
1054 elif dtype.type is np.bytes_:
1055 metadata[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(dtype.itemsize)
1058def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
1059 """Append numpy multi-dimensional shapes to arrow metadata.
1061 All column types are handled, but the metadata is only modified for
1062 multi-dimensional columns.
1064 Parameters
1065 ----------
1066 metadata : `dict` [`bytes`, `str`]
1067 Metadata dictionary; modified in place.
1068 name : `str`
1069 Column name.
1070 dtype : `np.dtype`
1071 Numpy dtype.
1072 """
1073 if len(dtype.shape) > 1:
1074 metadata[f"lsst::arrow::shape::{name}".encode("UTF-8")] = str(dtype.shape)
1077def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
1078 """Retrieve the shape from the metadata, if available.
1080 Parameters
1081 ----------
1082 metadata : `dict` [`bytes`, `bytes`]
1083 Metadata dictionary.
1084 list_size : `int`
1085 Size of the list datatype.
1086 name : `str`
1087 Column name.
1089 Returns
1090 -------
1091 shape : `tuple` [`int`]
1092 Shape associated with the column.
1094 Raises
1095 ------
1096 RuntimeError
1097 Raised if metadata is found but has incorrect format.
1098 """
1099 md_name = f"lsst::arrow::shape::{name}"
1100 if (encoded := md_name.encode("UTF-8")) in metadata:
1101 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8"))
1102 if groups is None:
1103 raise RuntimeError("Illegal value found in metadata.")
1104 shape = tuple(int(x) for x in groups[1].split(",") if x != "")
1105 else:
1106 shape = (list_size,)
1108 return shape
def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Convert a pyarrow schema to a numpy dtype.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        A list with name, type pairs.
    """
    metadata = schema.metadata if schema.metadata is not None else {}

    dtype: list[Any] = []
    for name in schema.names:
        t = schema.field(name).type
        if isinstance(t, pa.FixedSizeListType):
            # Fixed-size lists map to numpy sub-array columns; the original
            # multi-dim shape is recovered from metadata when recorded.
            shape = _multidim_shape_from_metadata(metadata, t.list_size, name)
            dtype.append((name, (t.value_type.to_pandas_dtype(), shape)))
        elif t not in (pa.string(), pa.binary()):
            dtype.append((name, t.to_pandas_dtype()))
        else:
            # String/bytes columns need a fixed width for numpy.
            dtype.append((name, _arrow_string_to_numpy_dtype(schema, name)))

    return dtype
def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Convert a numpy dtype to a list of arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        Converted list of arrow types.
    """
    from math import prod

    import numpy as np

    type_list: list[Any] = []
    if dtype.names is None:
        # Not a structured dtype: no named columns to convert.
        return type_list

    for name in dtype.names:
        dt = dtype[name]
        arrow_type: Any
        if len(dt.shape) > 0:
            # Sub-array column: flatten to a fixed-size arrow list of the
            # base type; the original shape is preserved separately in
            # schema metadata by _append_numpy_multidim_metadata.
            arrow_type = pa.list_(
                pa.from_numpy_dtype(cast(tuple[np.dtype, tuple[int, ...]], dt.subdtype)[0].type),
                prod(dt.shape),
            )
        else:
            arrow_type = pa.from_numpy_dtype(dt.type)
        type_list.append((name, arrow_type))

    return type_list
1176def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
1177 """Extract equivalent table dtype from dict of numpy arrays.
1179 Parameters
1180 ----------
1181 numpy_dict : `dict` [`str`, `numpy.ndarray`]
1182 Dict with keys as the column names, values as the arrays.
1184 Returns
1185 -------
1186 dtype : `numpy.dtype`
1187 dtype of equivalent table.
1188 rowcount : `int`
1189 Number of rows in the table.
1191 Raises
1192 ------
1193 ValueError if columns in numpy_dict have unequal numbers of rows.
1194 """
1195 import numpy as np
1197 dtype_list = []
1198 rowcount = 0
1199 for name, col in numpy_dict.items():
1200 if rowcount == 0:
1201 rowcount = len(col)
1202 if len(col) != rowcount:
1203 raise ValueError(f"Column {name} has a different number of rows.")
1204 if len(col.shape) == 1:
1205 dtype_list.append((name, col.dtype))
1206 else:
1207 dtype_list.append((name, (col.dtype, col.shape[1:])))
1208 dtype = np.dtype(dtype_list)
1210 return (dtype, rowcount)
def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
        or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays.
    """
    import numpy as np

    if dtype.names is None:
        # Not a structured dtype; nothing to convert.
        return []

    converted: list[pa.Array] = []
    for col_name in dtype.names:
        field_dtype = dtype[col_name]
        arrow_field_type = schema.field(col_name).type
        data: Any
        if field_dtype.shape:
            # Multidimensional column: flatten and split into one
            # sub-array per row (empty tables produce an empty list).
            data = np.split(np_style_arrays[col_name].ravel(), rowcount) if rowcount > 0 else []
        else:
            data = np_style_arrays[col_name]

        try:
            converted.append(pa.array(data, type=arrow_field_type))
        except pa.ArrowNotImplementedError as err:
            # Check if data is big-endian (non-native byte order); arrow
            # only accepts little-endian buffers.
            is_big_endian = (np.little_endian and data.dtype.byteorder == ">") or (
                not np.little_endian and data.dtype.byteorder == "="
            )
            if not is_big_endian:
                # This failed for some other reason so raise the exception.
                raise err
            # Convert the array to little-endian before retrying.
            swapped = data.byteswap()
            swapped.dtype = swapped.dtype.newbyteorder("<")
            converted.append(pa.array(swapped, type=arrow_field_type))

    return converted
def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row
    group that targets the persisted size on disk (or smaller). The exact
    size on disk depends on the compression settings and ratios; typical
    binary data tables will have around 15-20% compression with the pyarrow
    default ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size.
    """
    bit_width = 0

    # Normalize to an empty dict so lookups below are always safe.
    metadata = schema.metadata if schema.metadata is not None else {}

    for name in schema.names:
        t = schema.field(name).type

        if t in (pa.string(), pa.binary()):
            md_name = f"lsst::arrow::len::{name}"

            if (encoded := md_name.encode("UTF-8")) in metadata:
                # String/bytes length from header. Read from the
                # None-safe local ``metadata``, not ``schema.metadata``,
                # which may be None.
                strlen = int(metadata[encoded])
            else:
                # We don't know the string width, so guess something.
                strlen = 10

            # Assuming UTF-8 encoding, and very few wide characters.
            t_width = 8 * strlen
        elif isinstance(t, pa.FixedSizeListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                t_width = t.list_size * t.value_type.bit_width
        elif t == pa.null():
            t_width = 0
        elif isinstance(t, pa.ListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                # This is a variable length list, just choose
                # something arbitrary.
                t_width = 10 * t.value_type.bit_width
        else:
            t_width = t.bit_width

        bit_width += t_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    if bit_width < 8:
        bit_width = 8

    byte_width = bit_width // 8

    return target_size // byte_width