Coverage for python/lsst/daf/butler/formatters/parquet.py: 12%
444 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-12 10:56 -0700
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-12 10:56 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = (
25 "ParquetFormatter",
26 "arrow_to_pandas",
27 "arrow_to_astropy",
28 "arrow_to_numpy",
29 "arrow_to_numpy_dict",
30 "pandas_to_arrow",
31 "pandas_to_astropy",
32 "astropy_to_arrow",
33 "numpy_to_arrow",
34 "numpy_to_astropy",
35 "numpy_dict_to_arrow",
36 "arrow_schema_to_pandas_index",
37 "DataFrameSchema",
38 "ArrowAstropySchema",
39 "ArrowNumpySchema",
40 "compute_row_group_size",
41)
43import collections.abc
44import itertools
45import json
46import re
47from collections.abc import Iterable, Sequence
48from typing import TYPE_CHECKING, Any, cast
50import pyarrow as pa
51import pyarrow.parquet as pq
52from lsst.daf.butler import Formatter
53from lsst.utils.introspection import get_full_type_name
54from lsst.utils.iteration import ensure_iterable
56if TYPE_CHECKING:
57 import astropy.table as atable
58 import numpy as np
59 import pandas as pd
# Byte budget used when sizing Parquet row groups.
# NOTE(review): presumably consumed by ``compute_row_group_size`` (defined
# outside this excerpt) -- confirm against that function.
TARGET_ROW_GROUP_BYTES = 1_000_000_000
class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)

        # ``pyarrow.Schema.metadata`` is `None` (not an empty dict) when the
        # file carries no key/value metadata.  Normalize it once so that the
        # membership tests below cannot raise ``TypeError``.
        metadata = schema.metadata if schema.metadata is not None else {}

        if component in ("columns", "schema"):
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            # No stored row count: read only the first column and measure it.
            temp_table = pq.read_table(
                path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # "columns" is the only supported parameter; it is popped so that
            # anything left over can be reported as unsupported below.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        """Write a table-like in-memory dataset to a Parquet file.

        Parameters
        ----------
        inMemoryDataset : `object`
            Dataset to write.  Accepted types are `pyarrow.Table`,
            `astropy.table.Table`, `numpy.ndarray`, a `dict` of numpy
            arrays, and `pandas.DataFrame`.

        Raises
        ------
        ValueError
            Raised if the dataset is of an unsupported type, or is a `dict`
            that cannot be converted as a dict of numpy arrays.
        """
        import numpy as np
        from astropy.table import Table as astropyTable

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif hasattr(inMemoryDataset, "to_parquet"):
            # This may be a pandas DataFrame.  pandas is only imported when
            # the object looks like one, and its absence is tolerated.
            try:
                import pandas as pd
            except ImportError:
                pd = None

            if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)
def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Turn a pyarrow table into a pandas DataFrame.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.  Any ``pandas`` metadata stored in the table
        schema is used by pyarrow when building the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    # Single-threaded conversion; integer columns containing nulls are
    # requested as object columns.
    dataframe = arrow_table.to_pandas(
        use_threads=False,
        integer_object_nulls=True,
    )
    return dataframe
def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Turn a pyarrow table into an `astropy.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.  Astropy metadata serialized in the schema
        metadata, if present, is applied to the resulting table.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    columns = arrow_to_numpy_dict(arrow_table)
    result = Table(columns)

    # The schema metadata is `None` when nothing was stored.
    schema_metadata = arrow_table.schema.metadata
    if schema_metadata is None:
        schema_metadata = {}

    _apply_astropy_metadata(result, schema_metadata)
    return result
def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Turn a pyarrow table into a structured numpy array.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names match the column names
        of the input table.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    # Scalar columns map to a plain field; multi-dimensional columns map to
    # a sub-array field whose shape drops the leading (row) axis.
    field_specs = [
        (name, column.dtype) if len(column.shape) <= 1 else (name, (column.dtype, column.shape[1:]))
        for name, column in columns.items()
    ]

    return np.rec.fromarrays(columns.values(), dtype=field_specs)
def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Turn a pyarrow table into a dict of numpy arrays.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the converted column.  Columns that
        contain nulls are returned as numpy masked arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    if schema.metadata is not None:
        metadata = schema.metadata
    else:
        metadata = {}

    string_like_types = (pa.string(), pa.binary())
    converted = {}

    for name in schema.names:
        arrow_type = schema.field(name).type
        is_string_like = arrow_type in string_like_types
        column = arrow_table[name]

        if column.null_count == 0:
            # No nulls: a plain (unmasked) conversion is enough.
            values = column.to_numpy()
        else:
            # Nulls present: substitute a placeholder of a suitable type so
            # the conversion succeeds, then mask those entries.
            placeholder = "" if is_string_like else arrow_type.to_pandas_dtype()(0)
            values = np.ma.masked_array(
                data=column.fill_null(placeholder).to_numpy(),
                mask=column.is_null().to_numpy(),
            )

        if is_string_like:
            values = values.astype(_arrow_string_to_numpy_dtype(schema, name, values))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(values) > 0:
                values = np.stack(values)
            else:
                # An empty column must be coerced to the element type.
                values = values.astype(arrow_type.value_type.to_pandas_dtype())

            # Restore the per-row shape (from metadata when recorded).
            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            values = values.reshape((len(arrow_table), *shape))

        converted[name] = values

    return converted
def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Build a structured numpy array from a dict of numpy arrays.

    The conversion round-trips through an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured array with N rows and field names taken from the dict
        keys.
    """
    intermediate_table = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(intermediate_table)
def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Split a structured numpy array into a dict of numpy arrays.

    The conversion round-trips through an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    intermediate_table = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(intermediate_table)
def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Convert a numpy array table to an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields (a structured array).

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.

    Raises
    ------
    TypeError
        Raised if ``np_array`` is not a structured array, i.e. its dtype
        has no named fields.
    """
    # A plain (non-structured) array has ``dtype.names is None``.  Iterating
    # over that used to fail with an opaque "'NoneType' object is not
    # iterable"; report the real problem with the same exception type.
    if np_array.dtype.names is None:
        raise TypeError(
            "Input numpy array must be a structured array with named fields; "
            f"got dtype {np_array.dtype}."
        )

    type_list = _numpy_dtype_to_arrow_types(np_array.dtype)

    md = {}
    md[b"lsst::arrow::rowcount"] = str(len(np_array))

    # Record string lengths and multi-dimensional shapes per field.
    for name in np_array.dtype.names:
        _append_numpy_string_metadata(md, name, np_array.dtype[name])
        _append_numpy_multidim_metadata(md, name, np_array.dtype[name])

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        np_array.dtype,
        len(np_array),
        np_array,
        schema,
    )

    arrow_table = pa.Table.from_arrays(arrays, schema=schema)

    return arrow_table
def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    table_dtype, n_rows = _numpy_dict_to_dtype(numpy_dict)
    arrow_types = _numpy_dtype_to_arrow_types(table_dtype)

    metadata = {b"lsst::arrow::rowcount": str(n_rows)}

    # ``names`` is `None` for a dtype without fields; treat that as empty.
    for field_name in table_dtype.names or ():
        field_dtype = table_dtype[field_name]
        _append_numpy_string_metadata(metadata, field_name, field_dtype)
        _append_numpy_multidim_metadata(metadata, field_name, field_dtype)

    schema = pa.schema(arrow_types, metadata=metadata)
    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.  The astropy table metadata is
        serialized as YAML into the schema metadata.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    arrow_types = _numpy_dtype_to_arrow_types(table_dtype)

    n_rows = len(astropy_table)
    metadata = {b"lsst::arrow::rowcount": str(n_rows)}

    for field_name in table_dtype.names:
        field_dtype = table_dtype[field_name]
        _append_numpy_string_metadata(metadata, field_name, field_dtype)
        _append_numpy_multidim_metadata(metadata, field_name, field_dtype)

    # Keep the astropy-specific header as YAML text in the metadata.
    metadata[b"table_meta_yaml"] = "\n".join(meta.get_yaml_from_table(astropy_table))

    schema = pa.schema(arrow_types, metadata=metadata)
    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, astropy_table, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Convert an astropy table to a dict of numpy arrays.

    The conversion round-trips through an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    intermediate_table = astropy_to_arrow(astropy_table)
    return arrow_to_numpy_dict(intermediate_table)
def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column when it cannot be
        inferred from the data (the column is empty or holds only nulls).

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table, with the row count and per-column string
        lengths added to the schema metadata.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        # Columns whose names start with "__" are skipped.
        if name.startswith("__") or arrow_table[name].type != pa.string():
            continue

        # ``default`` covers both an empty column and a column in which
        # every value is null; a bare ``max()`` raises ValueError on the
        # latter because the filtered generator is empty.
        strlen = max(
            (len(row.as_py()) for row in arrow_table[name] if row.is_valid),
            default=default_length,
        )
        md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table
def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Turn a pandas dataframe into an astropy table, keeping the index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe columns are a `pandas.MultiIndex`.
    """
    import pandas as pd
    from astropy.table import Table

    has_multi_index_columns = isinstance(dataframe.columns, pd.MultiIndex)
    if has_multi_index_columns:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table
def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Convert a pandas dataframe to a dict of numpy arrays.

    The conversion round-trips through an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    intermediate_table = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(intermediate_table)
def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Wrap a numpy table in an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Astropy table built from ``np_array`` with ``copy=False``.
    """
    from astropy.table import Table

    astropy_table = Table(data=np_array, copy=False)
    return astropy_table
def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index.  A `pandas.MultiIndex` is returned only when
        the ``pandas`` schema metadata declares more than one column index
        level; otherwise a flat index of the column names that do not start
        with ``__`` is returned.
    """
    import pandas as pd

    # ``schema.metadata`` is `None` when the schema has no metadata at all,
    # so a bare ``b"pandas" in schema.metadata`` would raise TypeError.
    metadata = schema.metadata if schema.metadata is not None else {}

    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]
    else:
        indexes = []

    if len(indexes) <= 1:
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])
def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """List the column names of an arrow schema.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Schema to read the names from.

    Returns
    -------
    column_list : `list` [`str`]
        New list holding every name in ``schema.names``, in order.
    """
    column_list = list(schema.names)
    return column_list
class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row dataframe that keeps the columns of
    the input.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean mask selects no rows but keeps the columns.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty row list plus the schema yields a zero-row table.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        # The zero-row dataframe that represents the schema.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)
class ArrowAstropySchema:
    """Schema wrapper backed by a zero-row astropy table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table whose columns define the schema.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        # Slice off every row; only the column structure is retained.
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Build an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Schema to convert.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The converted schema.
        """
        import numpy as np
        from astropy.table import Table

        empty_rows = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        table = Table(data=empty_rows)

        schema_metadata = schema.metadata
        if schema_metadata is None:
            schema_metadata = {}

        _apply_astropy_metadata(table, schema_metadata)

        return cls(table)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The converted pyarrow schema.
        """
        arrow_table = astropy_to_arrow(self._schema)
        return arrow_table.schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The converted dataframe schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return DataFrameSchema.from_arrow(arrow_schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Return the equivalent `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            The converted arrow numpy schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return ArrowNumpySchema.from_arrow(arrow_schema)

    @property
    def schema(self) -> atable.Table:
        # The zero-row table that represents the schema.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # Matching dtypes imply matching column names, so the per-column
        # lookups on ``other`` below are safe.
        if self._schema.dtype != other._schema.dtype:
            return False

        for name in self._schema.columns:
            mine = self._schema[name]
            theirs = other._schema[name]
            for attribute in ("unit", "description", "format"):
                if not getattr(mine, attribute) == getattr(theirs, attribute):
                    return False

        return True
class ArrowNumpySchema:
    """Schema wrapper backed by a numpy dtype.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Dtype that defines the schema.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Build an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            The converted schema.
        """
        import numpy as np

        field_specs = _schema_to_dtype_list(schema)
        return cls(np.dtype(field_specs))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Return the equivalent `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The converted arrow astropy schema.
        """
        import numpy as np

        # A zero-row array carries the dtype through the arrow conversion.
        empty_rows = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty_rows).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The converted dataframe schema.
        """
        import numpy as np

        empty_rows = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty_rows).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The converted pyarrow schema.
        """
        import numpy as np

        empty_rows = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty_rows).schema

    @property
    def schema(self) -> np.dtype:
        # The wrapped dtype, exactly as passed to the constructor.
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        return bool(self._dtype == other._dtype)
def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]:
    """Recover multi-index column tuples from their flat string form.

    PyArrow stores Pandas' multi-index column names (tuples in Python) as
    flat strings on disk.  This routine parses those strings back into
    tuples.  Strings that do not look like an ``n``-level tuple are
    dropped from the result.

    Parameters
    ----------
    n : `int`
        Number of levels in the `pandas.MultiIndex` being reconstructed.
    names : `~collections.abc.Iterable` [`str`]
        Strings to be split.

    Returns
    -------
    column_names : `list` [`tuple` [`str`]]
        A list of multi-index column name tuples.
    """
    # One quoted capture group per index level, comma separated and wrapped
    # in literal parentheses.
    quoted_levels = ", ".join(["'(.*)'"] * n)
    pattern = re.compile(r"\({}\)".format(quoted_levels))

    return [match.groups() for name in names if (match := pattern.search(name)) is not None]
def _standardize_multi_index_columns(
    pd_index: pd.MultiIndex,
    columns: Any,
    stringify: bool = True,
) -> list[str | Sequence[Any]]:
    """Normalize a multi-index column selection into a flat list.

    The selection may be a list of tuples (one tuple per requested column)
    or a mapping from index level name to the value(s) wanted for that
    level.  Levels missing from the mapping select every value of that
    level.

    Parameters
    ----------
    pd_index : `pandas.MultiIndex`
        Pandas multi-index.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Columns to standardize.
    stringify : `bool`, optional
        If `True` each selected column is returned as ``str(tuple)``;
        otherwise the tuple itself is returned.

    Returns
    -------
    names : `list` [`str`]
        Selected columns, stringified unless ``stringify`` is `False`.

    Raises
    ------
    ValueError
        Raised if ``columns`` is neither a list of tuples nor a mapping, if
        the mapping uses keys that are not index level names, or if a
        requested value is not present in the corresponding level.
    """
    level_names = tuple(pd_index.names)
    selected: list[str | Sequence[Any]] = []

    if isinstance(columns, list):
        for entry in columns:
            if not isinstance(entry, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(entry)}."
                )
            selected.append(str(entry) if stringify else entry)
        return selected

    if not isinstance(columns, collections.abc.Mapping):
        raise ValueError(
            "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
            f"Instead got a {get_full_type_name(columns)}."
        )
    if not set(level_names).issuperset(columns.keys()):
        raise ValueError(
            f"Cannot use dict with keys {set(columns.keys())} to select columns from {level_names}."
        )

    # One iterable of candidate values per level, defaulting to all values.
    per_level_values = [
        ensure_iterable(columns.get(level_name, pd_index.levels[position]))
        for position, level_name in enumerate(level_names)
    ]

    for combination in itertools.product(*per_level_values):
        for position, value in enumerate(combination):
            if value not in pd_index.levels[position]:
                raise ValueError(f"Unrecognized value {value!r} for index {level_names[position]!r}.")
        selected.append(str(combination) if stringify else combination)

    return selected
def _apply_astropy_metadata(astropy_table: atable.Table, metadata: dict) -> None:
    """Restore astropy metadata stored in arrow schema metadata.

    Nothing is changed when the ``table_meta_yaml`` entry is missing or
    empty.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to update in place.
    metadata : `dict` [`bytes`]
        Schema metadata dict.
    """
    from astropy.table import meta

    serialized = metadata.get(b"table_meta_yaml", None)
    if not serialized:
        return

    yaml_lines = serialized.decode("UTF8").split("\n")
    header = meta.get_header_from_yaml(yaml_lines)

    # Copy description, format, unit and meta back onto each column from
    # the per-column entries that were serialized with the table.
    column_headers = {entry["name"]: entry for entry in header["datatype"]}
    for column in astropy_table.columns.values():
        column_header = column_headers[column.name]
        for attribute in ("description", "format", "unit", "meta"):
            if attribute in column_header:
                setattr(column, attribute, column_header[attribute])

    if "meta" in header:
        astropy_table.meta.update(header["meta"])
def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Work out the numpy dtype string for an arrow string/binary column.

    The length comes from the ``lsst::arrow::len::<name>`` schema metadata
    entry when present, otherwise from the longest value in
    ``numpy_column`` when one is supplied and non-empty, otherwise from
    ``default_length``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column values used to measure the string length.
    default_length : `int`, optional
        Length used when it is neither in the metadata nor measurable from
        the column.

    Returns
    -------
    dtype_str : `str`
        ``U<len>`` for arrow string columns, ``|S<len>`` otherwise.
    """
    schema_metadata = schema.metadata if schema.metadata is not None else {}
    length_key = f"lsst::arrow::len::{name}".encode("UTF-8")

    if length_key in schema_metadata:
        # String/bytes length recorded in the header.
        length = int(schema_metadata[length_key])
    elif numpy_column is not None and len(numpy_column) > 0:
        length = max(len(value) for value in numpy_column)
    else:
        length = default_length

    if schema.field(name).type == pa.string():
        return f"U{length}"
    return f"|S{length}"
def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Record the string length of a numpy column in arrow metadata.

    Any dtype may be passed; only unicode and bytes dtypes add an entry.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype of the column.
    """
    import numpy as np

    kind = dtype.type
    if kind is np.str_:
        # The item size is divided by 4 to get the length in characters.
        length = dtype.itemsize // 4
    elif kind is np.bytes_:
        length = dtype.itemsize
    else:
        return

    metadata[f"lsst::arrow::len::{name}".encode()] = str(length)
def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Record the shape of a multi-dimensional numpy column in metadata.

    Any dtype may be passed; an entry is added only when the dtype's shape
    has more than one dimension.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype of the column.
    """
    shape = dtype.shape
    if len(shape) <= 1:
        return

    shape_key = f"lsst::arrow::shape::{name}".encode()
    metadata[shape_key] = str(shape)
def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
    """Look up the per-row shape of a column in the metadata.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `bytes`]
        Metadata dictionary.
    list_size : `int`
        Size of the list datatype; used as a one-dimensional shape when no
        shape is recorded for the column.
    name : `str`
        Column name.

    Returns
    -------
    shape : `tuple` [`int`]
        Shape associated with the column.

    Raises
    ------
    RuntimeError
        Raised if a metadata entry is found but contains no parenthesized
        value.
    """
    shape_key = f"lsst::arrow::shape::{name}".encode("UTF-8")
    if shape_key not in metadata:
        return (list_size,)

    match = re.search(r"\((.*)\)", metadata[shape_key].decode("UTF-8"))
    if match is None:
        raise RuntimeError("Illegal value found in metadata.")

    # Empty pieces (e.g. after a trailing comma) are ignored.
    dimensions = [int(piece) for piece in match[1].split(",") if piece != ""]
    return tuple(dimensions)
def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Translate a pyarrow schema into a numpy dtype description.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Schema to translate.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        One ``(name, type)`` pair per schema field, in schema order.
    """
    schema_metadata = schema.metadata if schema.metadata is not None else {}
    string_like_types = (pa.string(), pa.binary())

    fields: list[Any] = []
    for name in schema.names:
        arrow_type = schema.field(name).type
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists become sub-array fields.
            shape = _multidim_shape_from_metadata(schema_metadata, arrow_type.list_size, name)
            field_type = (arrow_type.value_type.to_pandas_dtype(), shape)
        elif arrow_type in string_like_types:
            field_type = _arrow_string_to_numpy_dtype(schema, name)
        else:
            field_type = arrow_type.to_pandas_dtype()
        fields.append((name, field_type))

    return fields
def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Translate a numpy dtype into a list of arrow field types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        One ``(name, arrow_type)`` pair per field; empty when the dtype has
        no named fields.
    """
    from math import prod

    import numpy as np

    arrow_fields: list[Any] = []
    if dtype.names is None:
        return arrow_fields

    for name in dtype.names:
        field_dtype = dtype[name]
        arrow_type: Any
        if field_dtype.shape:
            # Sub-array field: a fixed-size list whose length is the total
            # number of elements in the sub-array.
            base_dtype = cast(tuple[np.dtype, tuple[int, ...]], field_dtype.subdtype)[0]
            arrow_type = pa.list_(pa.from_numpy_dtype(base_dtype.type), prod(field_dtype.shape))
        else:
            arrow_type = pa.from_numpy_dtype(field_dtype.type)
        arrow_fields.append((name, arrow_type))

    return arrow_fields
def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
    """Extract equivalent table dtype from dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    dtype : `numpy.dtype`
        dtype of equivalent table. One-dimensional columns contribute their
        own dtype; columns with more dimensions contribute a sub-array
        field whose shape is the column shape without the leading (row)
        axis.
    rowcount : `int`
        Number of rows in the table (0 if ``numpy_dict`` is empty).

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    import numpy as np

    dtype_list = []
    # ``None`` means "no column seen yet". Using 0 for that would let an
    # empty first column be silently overridden by a later, longer column.
    rowcount: int | None = None
    for name, col in numpy_dict.items():
        nrows = len(col)
        if rowcount is None:
            rowcount = nrows
        elif nrows != rowcount:
            raise ValueError(f"Column {name} has a different number of rows.")
        if len(col.shape) == 1:
            dtype_list.append((name, col.dtype))
        else:
            dtype_list.append((name, (col.dtype, col.shape[1:])))
    dtype = np.dtype(dtype_list)

    return (dtype, 0 if rowcount is None else rowcount)
def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
        or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays, one per named field of ``dtype``
        (empty if ``dtype`` has no named fields).

    Raises
    ------
    pyarrow.ArrowNotImplementedError
        Raised if pyarrow cannot convert a column and the column is not
        big-endian (in which case a byte-swapped retry is attempted).
    """
    import numpy as np

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    def _arrow_input(column: Any, multidim: bool) -> Any:
        # Multi-dimensional columns are flattened and handed to arrow as
        # one flat array per row.
        if multidim:
            return np.split(column.ravel(), rowcount)
        return column

    for name in dtype.names:
        arrow_type = schema.field(name).type
        multidim = len(dtype[name].shape) > 0

        if multidim and rowcount == 0:
            # np.split cannot divide into zero sections; an empty table
            # simply produces an empty list column.
            arrow_arrays.append(pa.array([], type=arrow_type))
            continue

        column = np_style_arrays[name]
        try:
            arrow_arrays.append(pa.array(_arrow_input(column, multidim), type=arrow_type))
        except pa.ArrowNotImplementedError:
            # Check if the column is big-endian. The source column is
            # inspected (not the per-row list built for multi-dimensional
            # columns) because a list has no dtype to look at.
            byteorder = getattr(getattr(column, "dtype", None), "byteorder", None)
            is_big_endian = (np.little_endian and byteorder == ">") or (
                not np.little_endian and byteorder == "="
            )
            if not is_big_endian:
                # This failed for some other reason so raise the exception.
                raise

            # We need to convert the array to little-endian: swap the bytes
            # and reinterpret them with the matching little-endian dtype.
            swapped = column.byteswap().view(column.dtype.newbyteorder("<"))
            arrow_arrays.append(pa.array(_arrow_input(swapped, multidim), type=arrow_type))

    return arrow_arrays
def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row group
    that targets the persisted size on disk (or smaller).  The exact size on
    disk depends on the compression settings and ratios; typical binary data
    tables will have around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size. This is always
        at least 1, even if a single row is estimated to be wider than
        ``target_size``.
    """
    bit_width = 0

    metadata = schema.metadata if schema.metadata is not None else {}
    string_like = (pa.string(), pa.binary())

    for name in schema.names:
        t = schema.field(name).type

        if t in string_like:
            md_name = f"lsst::arrow::len::{name}"

            if (encoded := md_name.encode("UTF-8")) in metadata:
                # String/bytes length from header.
                strlen = int(metadata[encoded])
            else:
                # We don't know the string width, so guess something.
                strlen = 10

            # Assuming UTF-8 encoding, and very few wide characters.
            t_width = 8 * strlen
        elif isinstance(t, pa.FixedSizeListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                t_width = t.list_size * t.value_type.bit_width
        elif t == pa.null():
            t_width = 0
        elif isinstance(t, pa.ListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                # This is a variable length list, just choose
                # something arbitrary.
                t_width = 10 * t.value_type.bit_width
        else:
            t_width = t.bit_width

        bit_width += t_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    byte_width = max(bit_width, 8) // 8

    # A row group must hold at least one row; without the clamp a target
    # smaller than one row's width would yield an unusable size of 0.
    return max(target_size // byte_width, 1)