Coverage for python/lsst/daf/butler/formatters/parquet.py: 13%
469 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-05 10:00 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-05 10:00 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
# Names exported by ``from <module> import *``; together they form the public
# API of this module (the formatter, the table/schema converters, and the
# schema wrapper classes).
__all__ = (
    "ParquetFormatter",
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    "pandas_to_arrow",
    "pandas_to_astropy",
    "astropy_to_arrow",
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
    "compute_row_group_size",
)
49import collections.abc
50import itertools
51import json
52import re
53from collections.abc import Iterable, Sequence
54from typing import TYPE_CHECKING, Any, cast
56import pyarrow as pa
57import pyarrow.parquet as pq
58from lsst.daf.butler import Formatter
59from lsst.utils.introspection import get_full_type_name
60from lsst.utils.iteration import ensure_iterable
62if TYPE_CHECKING:
63 import astropy.table as atable
64 import numpy as np
65 import pandas as pd
# Target size, in bytes, of a parquet row group (10**9 bytes).
# NOTE(review): presumably consumed by ``compute_row_group_size``, whose body
# is not visible in this part of the file -- confirm.
TARGET_ROW_GROUP_BYTES = 1_000_000_000
class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        #
        # Returns the arrow schema for the "columns"/"schema" components (or
        # when the read storage class is one of the schema types), an `int`
        # for the "rowcount" component, and otherwise a `pyarrow.Table`,
        # optionally restricted by the "columns" read parameter.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)

        # ``schema.metadata`` is `None` (not an empty dict) when the file
        # carries no key/value metadata.  Normalize it once so that every
        # membership test below is safe for such files.
        metadata = schema.metadata if schema.metadata is not None else {}

        schema_names = ["ArrowSchema", "DataFrameSchema", "ArrowAstropySchema", "ArrowNumpySchema"]

        if component in ("columns", "schema") or self.fileDescriptor.readStorageClass.name in schema_names:
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            # Fall back to reading a single column and measuring its length.
            temp_table = pq.read_table(
                path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # "columns" is consumed here; anything left over afterwards is an
            # unsupported parameter.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Accepts a `pyarrow.Table`, an astropy table, a numpy structured
        # array, a dict of numpy arrays, a pandas DataFrame, or a bare
        # `pyarrow.Schema` (written as metadata only).  Any other type raises
        # `ValueError`.
        import numpy as np
        from astropy.table import Table as astropyTable

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif isinstance(inMemoryDataset, pa.Schema):
            # A schema has no rows; only the metadata file is written.
            pq.write_metadata(inMemoryDataset, location.path)
            return
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame.  pandas is imported lazily
                # so that it stays an optional dependency.
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)
def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Build a pandas DataFrame from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Any ``pandas`` metadata stored in the schema is
        used by pyarrow when constructing the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    # Conversion is single-threaded, and ``integer_object_nulls`` is passed
    # so pyarrow handles integer columns containing nulls accordingly.
    dataframe = arrow_table.to_pandas(
        use_threads=False,
        integer_object_nulls=True,
    )
    return dataframe
def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Build an `astropy.Table` from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Astropy metadata found in the schema is applied
        to the resulting ``astropy.Table``.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    # Go through the dict-of-numpy-arrays representation, then copy over
    # whatever astropy metadata the schema carries.
    columns = arrow_to_numpy_dict(arrow_table)
    table = Table(columns)
    _apply_astropy_metadata(table, arrow_table.schema)
    return table
def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Build a structured numpy array from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names are the column names
        of the input table.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    # One-dimensional columns map to a plain field; columns with extra
    # dimensions map to a sub-array field with the per-row shape.
    dtype = [
        (name, col.dtype) if len(col.shape) <= 1 else (name, (col.dtype, col.shape[1:]))
        for name, col in columns.items()
    ]

    return np.rec.fromarrays(columns.values(), dtype=dtype)
def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array. Columns that contain
        nulls are returned as masked arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = schema.metadata if schema.metadata is not None else {}

    float_types = (pa.float64(), pa.float32(), pa.float16())
    signed_int_types = (pa.int64(), pa.int32(), pa.int16(), pa.int8())
    bool_types = (pa.bool_(),)
    text_types = (pa.string(), pa.binary())

    result = {}

    for name in schema.names:
        arrow_type = schema.field(name).type
        arrow_column = arrow_table[name]

        if arrow_column.null_count == 0:
            # No nulls: a plain array is sufficient.
            values = arrow_column.to_numpy()
        else:
            # Nulls present: have arrow substitute a value of a suitable
            # type, convert, and then mask the substituted entries.
            fill: Any
            if arrow_type in float_types:
                fill = np.nan
            elif arrow_type in signed_int_types:
                fill = -1
            elif arrow_type in bool_types:
                fill = True
            elif arrow_type in text_types:
                fill = ""
            else:
                # Everything else, unsigned integers in particular.
                fill = 0

            values = np.ma.masked_array(
                data=arrow_column.fill_null(fill).to_numpy(),
                mask=arrow_column.is_null().to_numpy(),
                fill_value=fill,
            )

        if arrow_type in text_types:
            values = values.astype(_arrow_string_to_numpy_dtype(schema, name, values))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(values) > 0:
                values = np.stack(values)
            else:
                # An empty column has to be coerced to the element type.
                values = values.astype(arrow_type.value_type.to_pandas_dtype())

            per_row_shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            values = values.reshape((len(arrow_table), *per_row_shape))

        result[name] = values

    return result
def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Build a structured numpy array from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names are the dict keys.
    """
    # The conversion is routed through an intermediate arrow table.
    arrow_table = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(arrow_table)
def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    # The conversion is routed through an intermediate arrow table.
    arrow_table = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(arrow_table)
def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Build an arrow table from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. Its schema metadata records the row
        count and, where applicable, string lengths and multi-dimensional
        shapes of the columns.
    """
    table_dtype = np_array.dtype
    n_rows = len(np_array)

    arrow_types = _numpy_dtype_to_arrow_types(table_dtype)

    schema_md = {b"lsst::arrow::rowcount": str(n_rows)}
    for field_name in table_dtype.names:
        field_dtype = table_dtype[field_name]
        _append_numpy_string_metadata(schema_md, field_name, field_dtype)
        _append_numpy_multidim_metadata(schema_md, field_name, field_dtype)

    schema = pa.schema(arrow_types, metadata=schema_md)

    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, np_array, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    table_dtype, n_rows = _numpy_dict_to_dtype(numpy_dict)
    arrow_types = _numpy_dtype_to_arrow_types(table_dtype)

    schema_md = {b"lsst::arrow::rowcount": str(n_rows)}

    # ``names`` is `None` for a dtype without fields; nothing to record then.
    field_names = table_dtype.names
    if field_names is not None:
        for field_name in field_names:
            field_dtype = table_dtype[field_name]
            _append_numpy_string_metadata(schema_md, field_name, field_dtype)
            _append_numpy_multidim_metadata(schema_md, field_name, field_dtype)

    schema = pa.schema(arrow_types, metadata=schema_md)

    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. The astropy table header is stored as
        YAML in the schema metadata, and column descriptions and units are
        stored as field metadata.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    n_rows = len(astropy_table)

    arrow_types = _numpy_dtype_to_arrow_types(table_dtype)

    schema_md = {b"lsst::arrow::rowcount": str(n_rows)}
    for column_name in table_dtype.names:
        column_dtype = table_dtype[column_name]
        _append_numpy_string_metadata(schema_md, column_name, column_dtype)
        _append_numpy_multidim_metadata(schema_md, column_name, column_dtype)

    # Serialize the full astropy header so it can be restored on read.
    schema_md[b"table_meta_yaml"] = "\n".join(meta.get_yaml_from_table(astropy_table))

    # Attach description/unit to each field, but only when they are set.
    fields = []
    for column_name, arrow_type in arrow_types:
        column = astropy_table[column_name]
        field_md = {}
        if description := column.description:
            field_md["description"] = description
        if unit := column.unit:
            field_md["unit"] = str(unit)
        fields.append(pa.field(column_name, arrow_type, metadata=field_md))

    schema = pa.schema(fields, metadata=schema_md)

    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, astropy_table, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    # The conversion is routed through an intermediate arrow table.
    arrow_table = astropy_to_arrow(astropy_table)
    return arrow_to_numpy_dict(arrow_table)
def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column when it cannot be
        inferred from the data, i.e. when the column is empty or contains
        only nulls.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table. The schema metadata is extended with the row
        count and with the maximum string length of every string column.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata.  ``schema.metadata`` may be `None` when a schema
    # has no metadata at all, so fall back to an empty dict.
    md = arrow_table.schema.metadata or {}

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if not name.startswith("__") and arrow_table[name].type == pa.string():
            # ``default=`` covers both an empty column and a column whose
            # rows are all null; without it ``max`` raises ValueError on
            # the empty sequence of valid rows.
            strlen = max(
                (len(row.as_py()) for row in arrow_table[name] if row.is_valid),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table
def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Build an astropy table from a pandas dataframe, keeping the index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe columns are a `pandas.MultiIndex`.
    """
    import pandas as pd
    from astropy.table import Table

    columns_are_multi_index = isinstance(dataframe.columns, pd.MultiIndex)
    if columns_are_multi_index:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table
def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pandas dataframe.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    # The conversion is routed through an intermediate arrow table.
    arrow_table = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(arrow_table)
def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Build an astropy table from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table, constructed with ``copy=False``.
    """
    from astropy.table import Table

    astropy_table = Table(data=np_array, copy=False)
    return astropy_table
def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema. It does not need to carry any metadata.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only when
        the ``pandas`` schema metadata declares more than one column index
        level; otherwise a flat `pandas.Index` of the column names that do
        not start with ``__`` is returned.
    """
    import pandas as pd

    # ``schema.metadata`` is `None` when the schema has no metadata, so a
    # bare ``in`` test on it would raise TypeError.
    metadata = schema.metadata if schema.metadata is not None else {}

    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]
    else:
        indexes = []

    if len(indexes) <= 1:
        return pd.Index(name for name in schema.names if not name.startswith("__"))

    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])
def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Return the column names of an arrow schema as a list of strings.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        A new list holding the schema's column names, in schema order.
    """
    column_list = list(schema.names)
    return column_list
class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row copy of the input dataframe, which
    keeps the columns, dtypes and index structure but none of the data.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean mask selects no rows.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # A table with no rows but the full schema.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        # The zero-row dataframe that represents the schema.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)
class ArrowAstropySchema:
    """Schema wrapper backed by a zero-row astropy table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table whose structure defines the schema.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        # Slicing to zero rows keeps the columns but drops the data.
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Create an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Arrow schema to convert.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The converted schema.
        """
        import numpy as np
        from astropy.table import Table

        # Build an empty structured array with the matching dtype, wrap it
        # in a table, and then restore units/descriptions from the schema.
        empty = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        table = Table(data=empty)
        _apply_astropy_metadata(table, schema)

        return cls(table)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The converted pyarrow schema.
        """
        arrow_table = astropy_to_arrow(self._schema)
        return arrow_table.schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The converted dataframe schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return DataFrameSchema.from_arrow(arrow_schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Return the equivalent `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            The converted arrow numpy schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return ArrowNumpySchema.from_arrow(arrow_schema)

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        mine, theirs = self._schema, other._schema

        # Equal dtypes imply equal column names, so the per-column lookups
        # below are valid on both tables.
        if mine.dtype != theirs.dtype:
            return False

        for name in mine.columns:
            for attribute in ("unit", "description", "format"):
                if not getattr(mine[name], attribute) == getattr(theirs[name], attribute):
                    return False

        return True
class ArrowNumpySchema:
    """Schema wrapper backed by a numpy dtype.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Dtype that defines the schema.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Create an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Arrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            The converted schema.
        """
        import numpy as np

        dtype_list = _schema_to_dtype_list(schema)
        return cls(np.dtype(dtype_list))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Return the equivalent `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The converted arrow astropy schema.
        """
        import numpy as np

        # A zero-row array of this dtype is enough to derive the schema.
        empty = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The converted dataframe schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The converted pyarrow schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        same_dtype = bool(self._dtype == other._dtype)
        return same_dtype
def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]:
    """Recover multi-index column tuples from their flattened string form.

    PyArrow stores Pandas' multi-index column names (tuples in Python) as
    flat strings on disk. This routine parses those strings back into
    tuples.

    Parameters
    ----------
    n : `int`
        Number of levels in the `pandas.MultiIndex` being reconstructed.
    names : `~collections.abc.Iterable` [`str`]
        Strings to be split.

    Returns
    -------
    column_names : `list` [`tuple` [`str`]]
        One tuple of level values per input string that looks like a
        stringified ``n``-tuple; strings that do not match are skipped.
    """
    # One quoted capture group per level, e.g. ``('(.*)', '(.*)')`` for n=2.
    level_groups = ", ".join("'(.*)'" for _ in range(n))
    pattern = re.compile(r"\({}\)".format(level_groups))

    matches = (pattern.search(name) for name in names)
    return [match.groups() for match in matches if match is not None]
def _standardize_multi_index_columns(
    pd_index: pd.MultiIndex,
    columns: Any,
    stringify: bool = True,
) -> list[str | Sequence[Any]]:
    """Expand a multi-index column selection into explicit column names.

    Parameters
    ----------
    pd_index : `pandas.MultiIndex`
        Multi-index describing the columns available.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Selection to standardize: either explicit column tuples, or a
        mapping from level name to the value(s) wanted at that level.
        Levels missing from the mapping select every value of the level.
    stringify : `bool`, optional
        If `True`, return ``str(tuple)`` for each column (the form PyArrow
        uses); otherwise return the tuples themselves.

    Returns
    -------
    names : `list` [`str`]
        The selected columns, stringified if requested.

    Raises
    ------
    ValueError
        Raised if ``columns`` is neither a list of tuples nor a mapping, if
        the mapping uses keys that are not index level names, or if the
        mapping requests a value that is not present in a level.
    """
    level_names = tuple(pd_index.names)
    selected: list[str | Sequence[Any]] = []

    # Explicit list of tuples: only the element type is checked.
    if isinstance(columns, list):
        for entry in columns:
            if not isinstance(entry, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(entry)}."
                )
            selected.append(str(entry) if stringify else entry)
        return selected

    if not isinstance(columns, collections.abc.Mapping):
        raise ValueError(
            "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
            f"Instead got a {get_full_type_name(columns)}."
        )
    if not set(level_names).issuperset(columns.keys()):
        raise ValueError(
            f"Cannot use dict with keys {set(columns.keys())} to select columns from {level_names}."
        )

    # Per-level candidate values; the cartesian product gives the columns.
    per_level_values = [
        ensure_iterable(columns.get(level_name, pd_index.levels[position]))
        for position, level_name in enumerate(level_names)
    ]
    for combination in itertools.product(*per_level_values):
        for position, value in enumerate(combination):
            if value not in pd_index.levels[position]:
                raise ValueError(f"Unrecognized value {value!r} for index {level_names[position]!r}.")
        selected.append(str(combination) if stringify else combination)

    return selected
def _apply_astropy_metadata(astropy_table: atable.Table, arrow_schema: pa.Schema) -> None:
    """Copy astropy metadata stored in an arrow schema onto a table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to update; modified in place.
    arrow_schema : `pyarrow.Schema`
        Arrow schema that may carry the metadata.
    """
    from astropy.table import meta

    schema_md = arrow_schema.metadata if arrow_schema.metadata is not None else {}

    serialized_header = schema_md.get(b"table_meta_yaml", None)

    if not serialized_header:
        # No astropy header yaml; fall back to per-field arrow metadata.
        for name in arrow_schema.names:
            field_md = arrow_schema.field(name).metadata
            if field_md is None:
                continue
            if b"description" in field_md:
                description = field_md[b"description"].decode("UTF-8")
                if description != "":
                    astropy_table[name].description = description
            if b"unit" in field_md:
                unit = field_md[b"unit"].decode("UTF-8")
                if unit != "":
                    astropy_table[name].unit = unit
        return

    # A special astropy metadata header yaml is present: parse it.
    header_lines = serialized_header.decode("UTF8").split("\n")
    header = meta.get_header_from_yaml(header_lines)

    # Restore description, format, unit and meta for each column from the
    # column metadata that was serialized with the table.
    column_headers = {entry["name"]: entry for entry in header["datatype"]}
    for column in astropy_table.columns.values():
        column_header = column_headers[column.name]
        for attribute in ("description", "format", "unit", "meta"):
            if attribute in column_header:
                setattr(column, attribute, column_header[attribute])

    if "meta" in header:
        astropy_table.meta.update(header["meta"])
def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Return the numpy dtype string for an arrow string or binary column.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column data, used to measure the length when the schema metadata
        does not record it.
    default_length : `int`, optional
        Length used when it is neither in the metadata nor measurable from
        ``numpy_column``.

    Returns
    -------
    dtype_str : `str`
        ``U<len>`` for arrow string columns, ``|S<len>`` otherwise.
    """
    schema_md = schema.metadata if schema.metadata is not None else {}
    length_key = f"lsst::arrow::len::{name}".encode("UTF-8")

    if length_key in schema_md:
        # The length recorded in the header takes precedence.
        strlen = int(schema_md[length_key])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(len(row) for row in numpy_column)
    else:
        strlen = default_length

    if schema.field(name).type == pa.string():
        return f"U{strlen}"
    return f"|S{strlen}"
def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Record the length of a numpy string or bytes column in arrow metadata.

    Any dtype may be passed; the metadata is only changed for string and
    bytes dtypes.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype of the column.
    """
    import numpy as np

    if dtype.type is np.str_:
        # The item size is divided by 4 to get the length in characters.
        length = dtype.itemsize // 4
    elif dtype.type is np.bytes_:
        length = dtype.itemsize
    else:
        return

    metadata[f"lsst::arrow::len::{name}".encode()] = str(length)
def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Record the shape of a multi-dimensional numpy column in arrow metadata.

    Any dtype may be passed; the metadata is only changed when the dtype's
    shape has more than one dimension.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype of the column.
    """
    column_shape = dtype.shape
    if len(column_shape) <= 1:
        return

    shape_key = f"lsst::arrow::shape::{name}".encode()
    metadata[shape_key] = str(column_shape)
1125def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
1126 """Retrieve the shape from the metadata, if available.
1128 Parameters
1129 ----------
1130 metadata : `dict` [`bytes`, `bytes`]
1131 Metadata dictionary.
1132 list_size : `int`
1133 Size of the list datatype.
1134 name : `str`
1135 Column name.
1137 Returns
1138 -------
1139 shape : `tuple` [`int`]
1140 Shape associated with the column.
1142 Raises
1143 ------
1144 RuntimeError
1145 Raised if metadata is found but has incorrect format.
1146 """
1147 md_name = f"lsst::arrow::shape::{name}"
1148 if (encoded := md_name.encode("UTF-8")) in metadata:
1149 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8"))
1150 if groups is None:
1151 raise RuntimeError("Illegal value found in metadata.")
1152 shape = tuple(int(x) for x in groups[1].split(",") if x != "")
1153 else:
1154 shape = (list_size,)
1156 return shape
def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Build a numpy dtype description from a pyarrow schema.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        A list with name, type pairs.
    """
    metadata = schema.metadata
    if metadata is None:
        metadata = {}

    string_like = (pa.string(), pa.binary())

    def _field_dtype(column: str) -> Any:
        arrow_type = schema.field(column).type
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists become sub-array dtypes of (base type, shape).
            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, column)
            return (arrow_type.value_type.to_pandas_dtype(), shape)
        if arrow_type in string_like:
            # String and binary columns need a fixed-width numpy dtype.
            return _arrow_string_to_numpy_dtype(schema, column)
        return arrow_type.to_pandas_dtype()

    return [(column, _field_dtype(column)) for column in schema.names]
1188def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
1189 """Convert a numpy dtype to a list of arrow types.
1191 Parameters
1192 ----------
1193 dtype : `numpy.dtype`
1194 Numpy dtype to convert.
1196 Returns
1197 -------
1198 type_list : `list` [`object`]
1199 Converted list of arrow types.
1200 """
1201 from math import prod
1203 import numpy as np
1205 type_list: list[Any] = []
1206 if dtype.names is None:
1207 return type_list
1209 for name in dtype.names:
1210 dt = dtype[name]
1211 arrow_type: Any
1212 if len(dt.shape) > 0:
1213 arrow_type = pa.list_(
1214 pa.from_numpy_dtype(cast(tuple[np.dtype, tuple[int, ...]], dt.subdtype)[0].type),
1215 prod(dt.shape),
1216 )
1217 else:
1218 arrow_type = pa.from_numpy_dtype(dt.type)
1219 type_list.append((name, arrow_type))
1221 return type_list
1224def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
1225 """Extract equivalent table dtype from dict of numpy arrays.
1227 Parameters
1228 ----------
1229 numpy_dict : `dict` [`str`, `numpy.ndarray`]
1230 Dict with keys as the column names, values as the arrays.
1232 Returns
1233 -------
1234 dtype : `numpy.dtype`
1235 dtype of equivalent table.
1236 rowcount : `int`
1237 Number of rows in the table.
1239 Raises
1240 ------
1241 ValueError if columns in numpy_dict have unequal numbers of rows.
1242 """
1243 import numpy as np
1245 dtype_list = []
1246 rowcount = 0
1247 for name, col in numpy_dict.items():
1248 if rowcount == 0:
1249 rowcount = len(col)
1250 if len(col) != rowcount:
1251 raise ValueError(f"Column {name} has a different number of rows.")
1252 if len(col.shape) == 1:
1253 dtype_list.append((name, col.dtype))
1254 else:
1255 dtype_list.append((name, (col.dtype, col.shape[1:])))
1256 dtype = np.dtype(dtype_list)
1258 return (dtype, rowcount)
1261def _numpy_style_arrays_to_arrow_arrays(
1262 dtype: np.dtype,
1263 rowcount: int,
1264 np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
1265 schema: pa.Schema,
1266) -> list[pa.Array]:
1267 """Convert numpy-style arrays to arrow arrays.
1269 Parameters
1270 ----------
1271 dtype : `numpy.dtype`
1272 Numpy dtype of input table/arrays.
1273 rowcount : `int`
1274 Number of rows in input table/arrays.
1275 np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
1276 or `astropy.table.Table`
1277 Arrays to convert to arrow.
1278 schema : `pyarrow.Schema`
1279 Schema of arrow table.
1281 Returns
1282 -------
1283 arrow_arrays : `list` [`pyarrow.Array`]
1284 List of converted pyarrow arrays.
1285 """
1286 import numpy as np
1288 arrow_arrays: list[pa.Array] = []
1289 if dtype.names is None:
1290 return arrow_arrays
1292 for name in dtype.names:
1293 dt = dtype[name]
1294 val: Any
1295 if len(dt.shape) > 0:
1296 if rowcount > 0:
1297 val = np.split(np_style_arrays[name].ravel(), rowcount)
1298 else:
1299 val = []
1300 else:
1301 val = np_style_arrays[name]
1303 try:
1304 arrow_arrays.append(pa.array(val, type=schema.field(name).type))
1305 except pa.ArrowNotImplementedError as err:
1306 # Check if val is big-endian.
1307 if (np.little_endian and val.dtype.byteorder == ">") or (
1308 not np.little_endian and val.dtype.byteorder == "="
1309 ):
1310 # We need to convert the array to little-endian.
1311 val2 = val.byteswap()
1312 val2.dtype = val2.dtype.newbyteorder("<")
1313 arrow_arrays.append(pa.array(val2, type=schema.field(name).type))
1314 else:
1315 # This failed for some other reason so raise the exception.
1316 raise err
1318 return arrow_arrays
def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row group
    that targets the persisted size on disk (or smaller). The exact size on
    disk depends on the compression settings and ratios; typical binary data
    tables will have around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size. Always at
        least 1, even if a single row is estimated to exceed ``target_size``.
    """
    bit_width = 0

    metadata = schema.metadata if schema.metadata is not None else {}

    for name in schema.names:
        t = schema.field(name).type

        if t in (pa.string(), pa.binary()):
            md_key = f"lsst::arrow::len::{name}".encode("UTF-8")

            if md_key in metadata:
                # String/bytes length from header.
                strlen = int(metadata[md_key])
            else:
                # We don't know the string width, so guess something.
                strlen = 10

            # Assuming UTF-8 encoding, and very few wide characters.
            t_width = 8 * strlen
        elif isinstance(t, pa.FixedSizeListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                t_width = t.list_size * t.value_type.bit_width
        elif t == pa.null():
            t_width = 0
        elif isinstance(t, pa.ListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                # This is a variable length list, just choose
                # something arbitrary.
                t_width = 10 * t.value_type.bit_width
        else:
            t_width = t.bit_width

        bit_width += t_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    byte_width = max(bit_width, 8) // 8

    # A row group must hold at least one row; without the clamp a row wider
    # than the target (or a non-positive target) would yield 0 or less.
    return max(target_size // byte_width, 1)