Coverage for python/lsst/daf/butler/formatters/parquet.py: 13%
475 statements
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = (
    "ParquetFormatter",
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    "pandas_to_arrow",
    "pandas_to_astropy",
    "astropy_to_arrow",
    "astropy_to_pandas",
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
    "compute_row_group_size",
)

import collections.abc
import itertools
import json
import re
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, cast

import pyarrow as pa
import pyarrow.parquet as pq
from lsst.daf.butler import Formatter
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

if TYPE_CHECKING:
    import astropy.table as atable
    import numpy as np
    import pandas as pd

# Default target size (in bytes) of a row group on disk; see
# compute_row_group_size.
TARGET_ROW_GROUP_BYTES = 1_000_000_000


class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        schema = pq.read_schema(self.fileDescriptor.location.path)

        schema_names = ["ArrowSchema", "DataFrameSchema", "ArrowAstropySchema", "ArrowNumpySchema"]

        if component in ("columns", "schema") or self.fileDescriptor.readStorageClass.name in schema_names:
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in schema.metadata:
                return int(schema.metadata[b"lsst::arrow::rowcount"])

            temp_table = pq.read_table(
                self.fileDescriptor.location.path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in schema.metadata:
                    md = json.loads(schema.metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        metadata = schema.metadata if schema.metadata is not None else {}
        arrow_table = pq.read_table(
            self.fileDescriptor.location.path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        import numpy as np
        from astropy.table import Table as astropyTable

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif isinstance(inMemoryDataset, pa.Schema):
            pq.write_metadata(inMemoryDataset, location.path)
            return
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame.
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)


def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Convert a pyarrow table to a pandas DataFrame.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table to convert. If the table has ``pandas`` metadata
        in the schema it will be used in the construction of the
        ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        Converted pandas dataframe.
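
    Examples
    --------
    A minimal, illustrative sketch; the column names are arbitrary:

    >>> import pyarrow as pa
    >>> arrow_table = pa.table({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
    >>> dataframe = arrow_to_pandas(arrow_table)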
203 """
204 return arrow_table.to_pandas(use_threads=False, integer_object_nulls=True)


def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Convert a pyarrow table to an `astropy.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table to convert. If the table has astropy unit
        metadata in the schema it will be used in the construction
        of the ``astropy.Table``.

    Returns
    -------
    table : `astropy.Table`
        Converted astropy table.
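
    Examples
    --------
    An illustrative sketch with a plain, unit-less table (column name
    arbitrary):

    >>> import pyarrow as pa
    >>> arrow_table = pa.table({"flux": [1.0, 2.0, 3.0]})
    >>> astropy_table = arrow_to_astropy(arrow_table)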
221 """
222 from astropy.table import Table
224 astropy_table = Table(arrow_to_numpy_dict(arrow_table))
226 _apply_astropy_metadata(astropy_table, arrow_table.schema)
228 return astropy_table


def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Convert a pyarrow table to a structured numpy array.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Numpy array table with N rows and the same column names
        as the input arrow table.
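
    Examples
    --------
    A minimal sketch (column names are arbitrary); the result is a record
    array with fields ``x`` and ``y``:

    >>> import pyarrow as pa
    >>> arrow_table = pa.table({"x": [1, 2], "y": [3.5, 4.5]})
    >>> array = arrow_to_numpy(arrow_table)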
244 """
245 import numpy as np
247 numpy_dict = arrow_to_numpy_dict(arrow_table)
249 dtype = []
250 for name, col in numpy_dict.items():
251 if len(shape := numpy_dict[name].shape) <= 1:
252 dtype.append((name, col.dtype))
253 else:
254 dtype.append((name, (col.dtype, shape[1:])))
256 array = np.rec.fromarrays(numpy_dict.values(), dtype=dtype)
258 return array


def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Convert a pyarrow table to a dict of numpy arrays.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
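
    Examples
    --------
    An illustrative sketch; a column containing nulls comes back as a numpy
    masked array:

    >>> import pyarrow as pa
    >>> arrow_table = pa.table({"x": [1, 2, None]})
    >>> numpy_dict = arrow_to_numpy_dict(arrow_table)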
273 """
274 import numpy as np
276 schema = arrow_table.schema
277 metadata = schema.metadata if schema.metadata is not None else {}
279 numpy_dict = {}
281 for name in schema.names:
282 t = schema.field(name).type
284 if arrow_table[name].null_count == 0:
285 # Regular non-masked column
286 col = arrow_table[name].to_numpy()
287 else:
288 # For a masked column, we need to ask arrow to fill the null
289 # values with an appropriately typed value before conversion.
290 # Then we apply the mask to get a masked array of the correct type.
291 null_value: Any
292 match t:
293 case t if t in (pa.float64(), pa.float32(), pa.float16()):
294 null_value = np.nan
295 case t if t in (pa.int64(), pa.int32(), pa.int16(), pa.int8()):
296 null_value = -1
297 case t if t in (pa.bool_(),):
298 null_value = True
299 case t if t in (pa.string(), pa.binary()):
300 null_value = ""
301 case _:
302 # This is the fallback for unsigned ints in particular.
303 null_value = 0
305 col = np.ma.masked_array(
306 data=arrow_table[name].fill_null(null_value).to_numpy(),
307 mask=arrow_table[name].is_null().to_numpy(),
308 fill_value=null_value,
309 )
311 if t in (pa.string(), pa.binary()):
312 col = col.astype(_arrow_string_to_numpy_dtype(schema, name, col))
313 elif isinstance(t, pa.FixedSizeListType):
314 if len(col) > 0:
315 col = np.stack(col)
316 else:
317 # this is an empty column, and needs to be coerced to type.
318 col = col.astype(t.value_type.to_pandas_dtype())
320 shape = _multidim_shape_from_metadata(metadata, t.list_size, name)
321 col = col.reshape((len(arrow_table), *shape))
323 numpy_dict[name] = col
325 return numpy_dict


def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Convert a dict of numpy arrays to a structured numpy array.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Numpy array table with N rows and column names from the dict keys.
    """
    return arrow_to_numpy(numpy_dict_to_arrow(numpy_dict))


def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Convert a structured numpy array to a dict of numpy arrays.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(numpy_to_arrow(np_array))


def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Convert a numpy array table to an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
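
    Examples
    --------
    A minimal sketch (field names and dtypes are arbitrary):

    >>> import numpy as np
    >>> np_array = np.zeros(3, dtype=[("a", "f8"), ("b", "i4")])
    >>> arrow_table = numpy_to_arrow(np_array)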
372 """
373 type_list = _numpy_dtype_to_arrow_types(np_array.dtype)
375 md = {}
376 md[b"lsst::arrow::rowcount"] = str(len(np_array))
378 for name in np_array.dtype.names:
379 _append_numpy_string_metadata(md, name, np_array.dtype[name])
380 _append_numpy_multidim_metadata(md, name, np_array.dtype[name])
382 schema = pa.schema(type_list, metadata=md)
384 arrays = _numpy_style_arrays_to_arrow_arrays(
385 np_array.dtype,
386 len(np_array),
387 np_array,
388 schema,
389 )
391 arrow_table = pa.Table.from_arrays(arrays, schema=schema)
393 return arrow_table


def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Convert a dict of numpy arrays to an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
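
    Examples
    --------
    A minimal sketch; all columns must have the same length:

    >>> import numpy as np
    >>> numpy_dict = {"a": np.arange(3), "b": np.ones(3)}
    >>> arrow_table = numpy_dict_to_arrow(numpy_dict)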
412 """
413 dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
414 type_list = _numpy_dtype_to_arrow_types(dtype)
416 md = {}
417 md[b"lsst::arrow::rowcount"] = str(rowcount)
419 if dtype.names is not None:
420 for name in dtype.names:
421 _append_numpy_string_metadata(md, name, dtype[name])
422 _append_numpy_multidim_metadata(md, name, dtype[name])
424 schema = pa.schema(type_list, metadata=md)
426 arrays = _numpy_style_arrays_to_arrow_arrays(
427 dtype,
428 rowcount,
429 numpy_dict,
430 schema,
431 )
433 arrow_table = pa.Table.from_arrays(arrays, schema=schema)
435 return arrow_table


def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Convert an astropy table to an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
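
    Examples
    --------
    An illustrative sketch; the column name and unit are arbitrary, and the
    unit is carried into the arrow field metadata:

    >>> from astropy.table import Table
    >>> astropy_table = Table({"flux": [1.0, 2.0]})
    >>> astropy_table["flux"].unit = "nJy"
    >>> arrow_table = astropy_to_arrow(astropy_table)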
450 """
451 from astropy.table import meta
453 type_list = _numpy_dtype_to_arrow_types(astropy_table.dtype)
455 md = {}
456 md[b"lsst::arrow::rowcount"] = str(len(astropy_table))
458 for name in astropy_table.dtype.names:
459 _append_numpy_string_metadata(md, name, astropy_table.dtype[name])
460 _append_numpy_multidim_metadata(md, name, astropy_table.dtype[name])
462 meta_yaml = meta.get_yaml_from_table(astropy_table)
463 meta_yaml_str = "\n".join(meta_yaml)
464 md[b"table_meta_yaml"] = meta_yaml_str
466 # Convert type list to fields with metadata.
467 fields = []
468 for name, pa_type in type_list:
469 field_metadata = {}
470 if description := astropy_table[name].description:
471 field_metadata["description"] = description
472 if unit := astropy_table[name].unit:
473 field_metadata["unit"] = str(unit)
474 fields.append(
475 pa.field(
476 name,
477 pa_type,
478 metadata=field_metadata,
479 )
480 )
482 schema = pa.schema(fields, metadata=md)
484 arrays = _numpy_style_arrays_to_arrow_arrays(
485 astropy_table.dtype,
486 len(astropy_table),
487 astropy_table,
488 schema,
489 )
491 arrow_table = pa.Table.from_arrays(arrays, schema=schema)
493 return arrow_table


def astropy_to_pandas(astropy_table: atable.Table, index: str | None = None) -> pd.DataFrame:
    """Convert an astropy table to a pandas dataframe via arrow.

    By going via arrow we avoid pandas masked column bugs (e.g.
    https://github.com/pandas-dev/pandas/issues/58173).

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.
    index : `str`, optional
        Name of column to set as index.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        Output pandas dataframe.
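
    Examples
    --------
    A minimal sketch using an arbitrary column as the index:

    >>> from astropy.table import Table
    >>> astropy_table = Table({"id": [1, 2], "flux": [3.0, 4.0]})
    >>> dataframe = astropy_to_pandas(astropy_table, index="id")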
513 """
514 dataframe = arrow_to_pandas(astropy_to_arrow(astropy_table))
516 if isinstance(index, str):
517 dataframe = dataframe.set_index(index)
518 elif index:
519 raise RuntimeError("index must be a string or None.")
521 return dataframe


def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Convert an astropy table to a dict of numpy arrays.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(astropy_to_arrow(astropy_table))


def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        Default string length to use when it is not in the metadata and
        cannot be inferred from the column.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
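
    Examples
    --------
    An illustrative sketch; the maximum string length of ``name`` is
    recorded in the schema metadata:

    >>> import pandas as pd
    >>> dataframe = pd.DataFrame({"name": ["alpha", "beta"], "value": [1, 2]})
    >>> arrow_table = pandas_to_arrow(dataframe)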
555 """
556 arrow_table = pa.Table.from_pandas(dataframe)
558 # Update the metadata
559 md = arrow_table.schema.metadata
561 md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)
563 # We loop through the arrow table columns because the datatypes have
564 # been checked and converted from pandas objects.
565 for name in arrow_table.column_names:
566 if not name.startswith("__") and arrow_table[name].type == pa.string():
567 if len(arrow_table[name]) > 0:
568 strlen = max(len(row.as_py()) for row in arrow_table[name] if row.is_valid)
569 else:
570 strlen = default_length
571 md[f"lsst::arrow::len::{name}".encode()] = str(strlen)
573 arrow_table = arrow_table.replace_schema_metadata(md)
575 return arrow_table


def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Convert a pandas dataframe to an astropy table, preserving indexes.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.
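
    Examples
    --------
    A minimal sketch (column name arbitrary):

    >>> import pandas as pd
    >>> dataframe = pd.DataFrame({"a": [1.5, 2.5]})
    >>> astropy_table = pandas_to_astropy(dataframe)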
590 """
591 import pandas as pd
593 if isinstance(dataframe.columns, pd.MultiIndex):
594 raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")
596 return arrow_to_astropy(pandas_to_arrow(dataframe))


def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Convert a pandas dataframe to a dict of numpy arrays.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(pandas_to_arrow(dataframe))


def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Convert a numpy table to an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.
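
    Examples
    --------
    A minimal sketch (field names are arbitrary):

    >>> import numpy as np
    >>> np_array = np.zeros(2, dtype=[("ra", "f8"), ("dec", "f8")])
    >>> astropy_table = numpy_to_astropy(np_array)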
627 """
628 from astropy.table import Table
630 return Table(data=np_array, copy=False)


def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index.
    """
    import pandas as pd

    if b"pandas" in schema.metadata:
        md = json.loads(schema.metadata[b"pandas"])
        indexes = md["column_indexes"]
        len_indexes = len(indexes)
    else:
        len_indexes = 0

    if len_indexes <= 1:
        return pd.Index(name for name in schema.names if not name.startswith("__"))
    else:
        raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
        return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])


def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Convert an arrow schema to a list of string column names.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        Converted list of column names.
    """
    return list(schema.names)


class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)


class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Convert an arrow schema into an `ArrowAstropySchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        dtype = _schema_to_dtype_list(schema)

        data = np.zeros(0, dtype=dtype)
        astropy_table = Table(data=data)

        _apply_astropy_metadata(astropy_table, schema)

        return cls(astropy_table)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        return astropy_to_arrow(self._schema).schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        return DataFrameSchema.from_arrow(astropy_to_arrow(self._schema).schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(astropy_to_arrow(self._schema).schema)

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # If this comparison passes then the two tables have the
        # same column names.
        if self._schema.dtype != other._schema.dtype:
            return False

        for name in self._schema.columns:
            if not self._schema[name].unit == other._schema[name].unit:
                return False
            if not self._schema[name].description == other._schema[name].description:
                return False
            if not self._schema[name].format == other._schema[name].format:
                return False

        return True


class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype to convert.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Convert an arrow schema into an `ArrowNumpySchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        dtype = _schema_to_dtype_list(schema)

        return cls(np.dtype(dtype))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np

        return ArrowAstropySchema.from_arrow(numpy_to_arrow(np.zeros(0, dtype=self._dtype)).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        import numpy as np

        return DataFrameSchema.from_arrow(numpy_to_arrow(np.zeros(0, dtype=self._dtype)).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to a `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        return numpy_to_arrow(np.zeros(0, dtype=self._dtype)).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        if not self._dtype == other._dtype:
            return False

        return True


def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]:
    """Split a string that represents a multi-index column.

    PyArrow maps Pandas' multi-index column names (which are tuples in
    Python) to flat strings on disk. This routine exists to reconstruct the
    original tuple.

    Parameters
    ----------
    n : `int`
        Number of levels in the `pandas.MultiIndex` that is being
        reconstructed.
    names : `~collections.abc.Iterable` [`str`]
        Strings to be split.

    Returns
    -------
    column_names : `list` [`tuple` [`str`]]
        A list of multi-index column name tuples.
    """
    column_names: list[Sequence[str]] = []

    pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n)))
    for name in names:
        m = re.search(pattern, name)
        if m is not None:
            column_names.append(m.groups())

    return column_names


def _standardize_multi_index_columns(
    pd_index: pd.MultiIndex,
    columns: Any,
    stringify: bool = True,
) -> list[str | Sequence[Any]]:
    """Transform a dictionary/iterable index from a multi-index column list
    into a string directly understandable by PyArrow.

    Parameters
    ----------
    pd_index : `pandas.MultiIndex`
        Pandas multi-index.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Columns to standardize.
    stringify : `bool`, optional
        Should the column names be stringified?

    Returns
    -------
    names : `list` [`str`]
        Stringified representation of a multi-index column name.
    """
    index_level_names = tuple(pd_index.names)

    names: list[str | Sequence[Any]] = []

    if isinstance(columns, list):
        for requested in columns:
            if not isinstance(requested, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(requested)}."
                )
            if stringify:
                names.append(str(requested))
            else:
                names.append(requested)
    else:
        if not isinstance(columns, collections.abc.Mapping):
            raise ValueError(
                "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                f"Instead got a {get_full_type_name(columns)}."
            )
        if not set(index_level_names).issuperset(columns.keys()):
            raise ValueError(
                f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
            )
        factors = [
            ensure_iterable(columns.get(level, pd_index.levels[i]))
            for i, level in enumerate(index_level_names)
        ]
        for requested in itertools.product(*factors):
            for i, value in enumerate(requested):
                if value not in pd_index.levels[i]:
                    raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.")
            if stringify:
                names.append(str(requested))
            else:
                names.append(requested)

    return names


def _apply_astropy_metadata(astropy_table: atable.Table, arrow_schema: pa.Schema) -> None:
    """Apply any astropy metadata from the schema metadata.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to apply metadata.
    arrow_schema : `pyarrow.Schema`
        Arrow schema with metadata.
    """
    from astropy.table import meta

    metadata = arrow_schema.metadata if arrow_schema.metadata is not None else {}

    # Check if we have a special astropy metadata header yaml.
    meta_yaml = metadata.get(b"table_meta_yaml", None)
    if meta_yaml:
        meta_yaml = meta_yaml.decode("UTF8").split("\n")
        meta_hdr = meta.get_header_from_yaml(meta_yaml)

        # Set description, format, unit, meta from the column
        # metadata that was serialized with the table.
        header_cols = {x["name"]: x for x in meta_hdr["datatype"]}
        for col in astropy_table.columns.values():
            for attr in ("description", "format", "unit", "meta"):
                if attr in header_cols[col.name]:
                    setattr(col, attr, header_cols[col.name][attr])

        if "meta" in meta_hdr:
            astropy_table.meta.update(meta_hdr["meta"])
    else:
        # If we don't have astropy header data, we may have arrow field
        # metadata.
        for name in arrow_schema.names:
            field_metadata = arrow_schema.field(name).metadata
            if field_metadata is None:
                continue
            if (
                b"description" in field_metadata
                and (description := field_metadata[b"description"].decode("UTF-8")) != ""
            ):
                astropy_table[name].description = description
            if b"unit" in field_metadata and (unit := field_metadata[b"unit"].decode("UTF-8")) != "":
                astropy_table[name].unit = unit


def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Get the numpy dtype string associated with an arrow column.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column to determine numpy string dtype.
    default_length : `int`, optional
        Default string length to use when it is not in the metadata and
        cannot be inferred from the column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string.
    """
    # Special-case for string and binary columns.
    md_name = f"lsst::arrow::len::{name}"
    strlen = default_length
    metadata = schema.metadata if schema.metadata is not None else {}
    if (encoded := md_name.encode("UTF-8")) in metadata:
        # String/bytes length from header.
        strlen = int(schema.metadata[encoded])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(len(row) for row in numpy_column if row)

    dtype = f"U{strlen}" if schema.field(name).type == pa.string() else f"|S{strlen}"

    return dtype


def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Append numpy string length keys to arrow metadata.

    All column types are handled, but the metadata is only modified for
    string and byte columns.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype.
    """
    import numpy as np

    if dtype.type is np.str_:
        # Numpy unicode strings are stored as 4 bytes per character.
        metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize // 4)
    elif dtype.type is np.bytes_:
        metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize)


def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Append numpy multi-dimensional shapes to arrow metadata.

    All column types are handled, but the metadata is only modified for
    multi-dimensional columns.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype.
    """
    if len(dtype.shape) > 1:
        metadata[f"lsst::arrow::shape::{name}".encode()] = str(dtype.shape)


def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
    """Retrieve the shape from the metadata, if available.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `bytes`]
        Metadata dictionary.
    list_size : `int`
        Size of the list datatype.
    name : `str`
        Column name.

    Returns
    -------
    shape : `tuple` [`int`]
        Shape associated with the column.

    Raises
    ------
    RuntimeError
        Raised if metadata is found but has incorrect format.
    """
    md_name = f"lsst::arrow::shape::{name}"
    if (encoded := md_name.encode("UTF-8")) in metadata:
        groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8"))
        if groups is None:
            raise RuntimeError("Illegal value found in metadata.")
        shape = tuple(int(x) for x in groups[1].split(",") if x != "")
    else:
        shape = (list_size,)

    return shape


def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Convert a pyarrow schema to a numpy dtype.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list : `list` [`tuple`]
        A list with name, type pairs.
    """
    metadata = schema.metadata if schema.metadata is not None else {}

    dtype: list[Any] = []
    for name in schema.names:
        t = schema.field(name).type
        if isinstance(t, pa.FixedSizeListType):
            shape = _multidim_shape_from_metadata(metadata, t.list_size, name)
            dtype.append((name, (t.value_type.to_pandas_dtype(), shape)))
        elif t not in (pa.string(), pa.binary()):
            dtype.append((name, t.to_pandas_dtype()))
        else:
            dtype.append((name, _arrow_string_to_numpy_dtype(schema, name)))

    return dtype


def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Convert a numpy dtype to a list of arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        Converted list of arrow types.
    """
    from math import prod

    import numpy as np

    type_list: list[Any] = []
    if dtype.names is None:
        return type_list

    for name in dtype.names:
        dt = dtype[name]
        arrow_type: Any
        if len(dt.shape) > 0:
            # Multi-dimensional fields are flattened into fixed-size lists.
            arrow_type = pa.list_(
                pa.from_numpy_dtype(cast(tuple[np.dtype, tuple[int, ...]], dt.subdtype)[0].type),
                prod(dt.shape),
            )
        else:
            arrow_type = pa.from_numpy_dtype(dt.type)
        type_list.append((name, arrow_type))

    return type_list


def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
    """Extract equivalent table dtype from dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    dtype : `numpy.dtype`
        dtype of equivalent table.
    rowcount : `int`
        Number of rows in the table.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    import numpy as np

    dtype_list = []
    rowcount = 0
    for name, col in numpy_dict.items():
        if rowcount == 0:
            rowcount = len(col)
        if len(col) != rowcount:
            raise ValueError(f"Column {name} has a different number of rows.")
        if len(col.shape) == 1:
            dtype_list.append((name, col.dtype))
        else:
            dtype_list.append((name, (col.dtype, col.shape[1:])))
    dtype = np.dtype(dtype_list)

    return (dtype, rowcount)


def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray` \
            or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays.
    """
    import numpy as np

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    for name in dtype.names:
        dt = dtype[name]
        val: Any
        if len(dt.shape) > 0:
            # Split a multi-dimensional column into one flat array per row.
            if rowcount > 0:
                val = np.split(np_style_arrays[name].ravel(), rowcount)
            else:
                val = []
        else:
            val = np_style_arrays[name]

        try:
            arrow_arrays.append(pa.array(val, type=schema.field(name).type))
        except pa.ArrowNotImplementedError as err:
            # Check if val is big-endian.
            if (np.little_endian and val.dtype.byteorder == ">") or (
                not np.little_endian and val.dtype.byteorder == "="
            ):
                # We need to convert the array to little-endian.
                val2 = val.byteswap()
                val2.dtype = val2.dtype.newbyteorder("<")
                arrow_arrays.append(pa.array(val2, type=schema.field(name).type))
            else:
                # This failed for some other reason so raise the exception.
                raise err

    return arrow_arrays


def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row
    group that targets the persisted size on disk (or smaller). The exact
    size on disk depends on the compression settings and ratios; typical
    binary data tables will have around 15-20% compression with the pyarrow
    default ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size.
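
    Examples
    --------
    An illustrative sketch: one float64 plus one int32 column is 96 bits
    (12 bytes) per row, so the default 1 GB target yields roughly 83
    million rows per row group:

    >>> import pyarrow as pa
    >>> schema = pa.schema([("x", pa.float64()), ("y", pa.int32())])
    >>> row_group_size = compute_row_group_size(schema)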
1369 """
1370 bit_width = 0
1372 metadata = schema.metadata if schema.metadata is not None else {}
1374 for name in schema.names:
1375 t = schema.field(name).type
1377 if t in (pa.string(), pa.binary()):
1378 md_name = f"lsst::arrow::len::{name}"
1380 if (encoded := md_name.encode("UTF-8")) in metadata:
1381 # String/bytes length from header.
1382 strlen = int(schema.metadata[encoded])
1383 else:
1384 # We don't know the string width, so guess something.
1385 strlen = 10
1387 # Assuming UTF-8 encoding, and very few wide characters.
1388 t_width = 8 * strlen
1389 elif isinstance(t, pa.FixedSizeListType):
1390 if t.value_type == pa.null():
1391 t_width = 0
1392 else:
1393 t_width = t.list_size * t.value_type.bit_width
1394 elif t == pa.null():
1395 t_width = 0
1396 elif isinstance(t, pa.ListType):
1397 if t.value_type == pa.null():
1398 t_width = 0
1399 else:
1400 # This is a variable length list, just choose
1401 # something arbitrary.
1402 t_width = 10 * t.value_type.bit_width
1403 else:
1404 t_width = t.bit_width
1406 bit_width += t_width
1408 # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
1409 if bit_width < 8:
1410 bit_width = 8
1412 byte_width = bit_width // 8
1414 return target_size // byte_width