Coverage for python/lsst/daf/butler/formatters/parquet.py: 13%
473 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-26 02:48 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-26 02:48 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
# Names exported by ``from ... import *``; this is the module's public API.
30__all__ = (
31 "ParquetFormatter",
32 "arrow_to_pandas",
33 "arrow_to_astropy",
34 "arrow_to_numpy",
35 "arrow_to_numpy_dict",
36 "pandas_to_arrow",
37 "pandas_to_astropy",
38 "astropy_to_arrow",
39 "numpy_to_arrow",
40 "numpy_to_astropy",
41 "numpy_dict_to_arrow",
42 "arrow_schema_to_pandas_index",
43 "DataFrameSchema",
44 "ArrowAstropySchema",
45 "ArrowNumpySchema",
46 "compute_row_group_size",
47)
49import collections.abc
50import itertools
51import json
52import re
53from collections.abc import Iterable, Sequence
54from typing import TYPE_CHECKING, Any, cast
56import pyarrow as pa
57import pyarrow.parquet as pq
58from lsst.daf.butler import Formatter
59from lsst.utils.introspection import get_full_type_name
60from lsst.utils.iteration import ensure_iterable
62if TYPE_CHECKING:
63 import astropy.table as atable
64 import numpy as np
65 import pandas as pd
# Target size, in bytes, of a parquet row group (10**9).
# NOTE(review): presumably consumed by ``compute_row_group_size``, which is
# not visible in this part of the file -- confirm.
67TARGET_ROW_GROUP_BYTES = 1_000_000_000
class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        schema = pq.read_schema(self.fileDescriptor.location.path)

        # ``schema.metadata`` is `None` (not an empty dict) when the file
        # carries no key/value metadata, so normalize it once; membership
        # tests against `None` would raise `TypeError`.
        metadata = schema.metadata if schema.metadata is not None else {}

        schema_names = ["ArrowSchema", "DataFrameSchema", "ArrowAstropySchema", "ArrowNumpySchema"]

        if component in ("columns", "schema") or self.fileDescriptor.readStorageClass.name in schema_names:
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            # Fall back to reading only the first column and counting rows.
            temp_table = pq.read_table(
                self.fileDescriptor.location.path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # "columns" is the only supported parameter; it is removed here
            # so that anything left over can be reported as unsupported.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            self.fileDescriptor.location.path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Accepts a pyarrow Table or Schema, an astropy Table, a numpy
        # structured array, a dict of numpy arrays, or a pandas DataFrame.
        # Everything except a bare Schema is converted to an arrow table
        # before being written.
        import numpy as np
        from astropy.table import Table as astropyTable

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif isinstance(inMemoryDataset, pa.Schema):
            # A bare schema is written as a metadata-only file.
            pq.write_metadata(inMemoryDataset, location.path)
            return
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)
def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Build a pandas DataFrame from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Any ``pandas`` metadata stored in its schema
        is used when constructing the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    # Single-threaded conversion; integer columns containing nulls are
    # returned as object columns.
    conversion_options = {"use_threads": False, "integer_object_nulls": True}
    dataframe = arrow_table.to_pandas(**conversion_options)
    return dataframe
def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Build an `astropy.Table` from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Astropy unit metadata stored in the schema,
        if any, is applied to the resulting ``astropy.Table``.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    columns = arrow_to_numpy_dict(arrow_table)
    result = Table(columns)
    # Restore units/descriptions/etc. recorded in the arrow schema.
    _apply_astropy_metadata(result, arrow_table.schema)
    return result
def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Build a structured numpy array from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names match the column
        names of the input table.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    # One-dimensional columns map to a plain field; multi-dimensional
    # columns become sub-array fields with the per-row shape.
    dtype = [
        (name, col.dtype) if len(col.shape) <= 1 else (name, (col.dtype, col.shape[1:]))
        for name, col in columns.items()
    ]

    return np.rec.fromarrays(columns.values(), dtype=dtype)
def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's array. Columns with
        nulls are returned as masked arrays, except floating-point columns
        where nulls are replaced by NaN.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = {} if schema.metadata is None else schema.metadata

    float_types = (pa.float64(), pa.float32(), pa.float16())
    signed_int_types = (pa.int64(), pa.int32(), pa.int16(), pa.int8())
    text_types = (pa.string(), pa.binary())

    result = {}

    for name in schema.names:
        arrow_type = schema.field(name).type
        column = arrow_table[name]

        if column.null_count == 0:
            # No nulls: a direct conversion is sufficient.
            values = column.to_numpy()
        else:
            # Nulls must be filled with a value of the right type before
            # conversion; the null positions are then recorded in a mask.
            fill: Any
            needs_mask = True
            if arrow_type in float_types:
                # NaN already marks missing floats, so no mask is needed.
                fill = np.nan
                needs_mask = False
            elif arrow_type in signed_int_types:
                fill = -1
            elif arrow_type == pa.bool_():
                fill = True
            elif arrow_type in text_types:
                fill = ""
            else:
                # This is the fallback for unsigned ints in particular.
                fill = 0

            filled = column.fill_null(fill).to_numpy()
            if needs_mask:
                values = np.ma.masked_array(
                    data=filled,
                    mask=column.is_null().to_numpy(),
                    fill_value=fill,
                )
            else:
                values = filled

        if arrow_type in text_types:
            # Convert object strings to fixed-width numpy strings/bytes.
            values = values.astype(_arrow_string_to_numpy_dtype(schema, name, values))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(values) > 0:
                values = np.stack(values)
            else:
                # this is an empty column, and needs to be coerced to type.
                values = values.astype(arrow_type.value_type.to_pandas_dtype())

            # Restore the per-row shape recorded in the schema metadata.
            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            values = values.reshape((len(arrow_table), *shape))

        result[name] = values

    return result
def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Build a structured numpy array from a dict of numpy arrays.

    The conversion round-trips through an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured array with N rows and field names taken from the keys.
    """
    arrow_table = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(arrow_table)
def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a structured numpy array.

    The conversion round-trips through an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    arrow_table = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(arrow_table)
def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Build an arrow table from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.
    """
    dtype = np_array.dtype
    nrows = len(np_array)

    type_list = _numpy_dtype_to_arrow_types(dtype)

    # Record the row count plus per-column string lengths and
    # multi-dimensional shapes in the schema metadata.
    md = {b"lsst::arrow::rowcount": str(nrows)}
    for name in dtype.names:
        field_dtype = dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    schema = pa.schema(type_list, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, nrows, np_array, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
    type_list = _numpy_dtype_to_arrow_types(dtype)

    # Record the row count plus per-column string lengths and
    # multi-dimensional shapes in the schema metadata.
    md = {b"lsst::arrow::rowcount": str(rowcount)}
    for name in dtype.names if dtype.names is not None else ():
        field_dtype = dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    schema = pa.schema(type_list, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, rowcount, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    nrows = len(astropy_table)

    type_list = _numpy_dtype_to_arrow_types(table_dtype)

    # Schema-level metadata: row count, string lengths, multi-dimensional
    # shapes, and the astropy YAML description of the table.
    md = {b"lsst::arrow::rowcount": str(nrows)}
    for name in table_dtype.names:
        field_dtype = table_dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    md[b"table_meta_yaml"] = "\n".join(meta.get_yaml_from_table(astropy_table))

    # Field-level metadata: description and unit, when set on the column.
    fields = []
    for name, pa_type in type_list:
        column = astropy_table[name]
        field_metadata = {}
        description = column.description
        if description:
            field_metadata["description"] = description
        unit = column.unit
        if unit:
            field_metadata["unit"] = str(unit)
        fields.append(pa.field(name, pa_type, metadata=field_metadata))

    schema = pa.schema(fields, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, nrows, astropy_table, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from an astropy table.

    The conversion round-trips through an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    arrow_table = astropy_to_arrow(astropy_table)
    return arrow_to_numpy_dict(arrow_table)
def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column when it cannot be
        inferred from the data (the column is empty or contains only
        nulls).

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata. ``schema.metadata`` may be `None` when there is
    # no metadata at all, so fall back to an empty dict.
    md = dict(arrow_table.schema.metadata or {})

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if not name.startswith("__") and arrow_table[name].type == pa.string():
            # ``default`` covers both an empty column and a column whose
            # rows are all null; without it ``max`` raises `ValueError`
            # on the empty sequence left after filtering out nulls.
            strlen = max(
                (len(row.as_py()) for row in arrow_table[name] if row.is_valid),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table
def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Build an astropy table from a pandas dataframe, keeping the index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe has multi-index columns.
    """
    import pandas as pd
    from astropy.table import Table

    has_multi_index_columns = isinstance(dataframe.columns, pd.MultiIndex)
    if has_multi_index_columns:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table
def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pandas dataframe.

    The conversion round-trips through an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    arrow_table = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(arrow_table)
def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Build an astropy table from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table, constructed with ``copy=False``.
    """
    from astropy.table import Table

    astropy_table = Table(data=np_array, copy=False)
    return astropy_table
def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only
        when the ``pandas`` schema metadata describes more than one
        column index level; otherwise a flat `pandas.Index` of the column
        names not starting with ``__`` is returned.
    """
    import pandas as pd

    # ``schema.metadata`` is `None` when the schema has no metadata; a
    # membership test against `None` would raise `TypeError`.
    metadata = schema.metadata if schema.metadata is not None else {}

    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]
    else:
        indexes = []

    if len(indexes) <= 1:
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    # Multi-index column names are stored as stringified tuples.
    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])
def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Return the column names of an arrow schema as a list of strings.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Schema to read the names from.

    Returns
    -------
    column_list : `list` [`str`]
        New list holding the schema's column names, in schema order.
    """
    column_list = [*schema.names]
    return column_list
class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row dataframe that has the same columns
    as the dataframe it was built from.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean mask selects no rows, leaving an empty frame.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty row list with an explicit schema gives a zero-row table.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        # The zero-row dataframe that carries the schema.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)
class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    The schema is stored as a zero-row slice of the given table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        # Keep only the structure: a zero-length slice of the table.
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Build an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        # Zero-row structured array with the equivalent numpy dtype.
        empty = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        table = Table(data=empty)
        _apply_astropy_metadata(table, schema)

        return cls(table)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = astropy_to_arrow(self._schema)
        return arrow_table.schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return DataFrameSchema.from_arrow(arrow_schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Return the equivalent `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return ArrowNumpySchema.from_arrow(arrow_schema)

    @property
    def schema(self) -> atable.Table:
        # The zero-row table that carries the schema.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        mine, theirs = self._schema, other._schema

        # Equal dtypes imply the same column names, so the per-column
        # lookups below are safe.
        if mine.dtype != theirs.dtype:
            return False

        for name in mine.columns:
            left, right = mine[name], theirs[name]
            for attr in ("unit", "description", "format"):
                if not getattr(left, attr) == getattr(right, attr):
                    return False

        return True
class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype to convert.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Build an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        dtype_list = _schema_to_dtype_list(schema)
        return cls(np.dtype(dtype_list))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Return the equivalent `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np

        # A zero-row array with this dtype is enough to derive the schema.
        empty = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty).schema

    @property
    def schema(self) -> np.dtype:
        # The wrapped numpy dtype.
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        return bool(self._dtype == other._dtype)
913def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]:
914 """Split a string that represents a multi-index column.
916 PyArrow maps Pandas' multi-index column names (which are tuples in Python)
917 to flat strings on disk. This routine exists to reconstruct the original
918 tuple.
920 Parameters
921 ----------
922 n : `int`
923 Number of levels in the `pandas.MultiIndex` that is being
924 reconstructed.
925 names : `~collections.abc.Iterable` [`str`]
926 Strings to be split.
928 Returns
929 -------
930 column_names : `list` [`tuple` [`str`]]
931 A list of multi-index column name tuples.
932 """
933 column_names: list[Sequence[str]] = []
935 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n)))
936 for name in names:
937 m = re.search(pattern, name)
938 if m is not None:
939 column_names.append(m.groups())
941 return column_names
944def _standardize_multi_index_columns(
945 pd_index: pd.MultiIndex,
946 columns: Any,
947 stringify: bool = True,
948) -> list[str | Sequence[Any]]:
949 """Transform a dictionary/iterable index from a multi-index column list
950 into a string directly understandable by PyArrow.
952 Parameters
953 ----------
954 pd_index : `pandas.MultiIndex`
955 Pandas multi-index.
956 columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
957 Columns to standardize.
958 stringify : `bool`, optional
959 Should the column names be stringified?
961 Returns
962 -------
963 names : `list` [`str`]
964 Stringified representation of a multi-index column name.
965 """
966 index_level_names = tuple(pd_index.names)
968 names: list[str | Sequence[Any]] = []
970 if isinstance(columns, list):
971 for requested in columns:
972 if not isinstance(requested, tuple):
973 raise ValueError(
974 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
975 f"Instead got a {get_full_type_name(requested)}."
976 )
977 if stringify:
978 names.append(str(requested))
979 else:
980 names.append(requested)
981 else:
982 if not isinstance(columns, collections.abc.Mapping):
983 raise ValueError(
984 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
985 f"Instead got a {get_full_type_name(columns)}."
986 )
987 if not set(index_level_names).issuperset(columns.keys()):
988 raise ValueError(
989 f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
990 )
991 factors = [
992 ensure_iterable(columns.get(level, pd_index.levels[i]))
993 for i, level in enumerate(index_level_names)
994 ]
995 for requested in itertools.product(*factors):
996 for i, value in enumerate(requested):
997 if value not in pd_index.levels[i]:
998 raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.")
999 if stringify:
1000 names.append(str(requested))
1001 else:
1002 names.append(requested)
1004 return names
def _apply_astropy_metadata(astropy_table: atable.Table, arrow_schema: pa.Schema) -> None:
    """Copy astropy metadata stored in an arrow schema onto a table.

    If the schema carries the serialized astropy YAML header it is used;
    otherwise per-field ``description`` and ``unit`` metadata are applied.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to apply metadata to; modified in place.
    arrow_schema : `pyarrow.Schema`
        Arrow schema with metadata.
    """
    from astropy.table import meta

    metadata = {} if arrow_schema.metadata is None else arrow_schema.metadata

    # Check if we have a special astropy metadata header yaml.
    meta_yaml = metadata.get(b"table_meta_yaml", None)

    if not meta_yaml:
        # No astropy header data; fall back to arrow field metadata.
        for name in arrow_schema.names:
            field_metadata = arrow_schema.field(name).metadata
            if field_metadata is None:
                continue
            if b"description" in field_metadata:
                description = field_metadata[b"description"].decode("UTF-8")
                if description != "":
                    astropy_table[name].description = description
            if b"unit" in field_metadata:
                unit = field_metadata[b"unit"].decode("UTF-8")
                if unit != "":
                    astropy_table[name].unit = unit
        return

    yaml_lines = meta_yaml.decode("UTF8").split("\n")
    meta_hdr = meta.get_header_from_yaml(yaml_lines)

    # Set description, format, unit, meta from the column
    # metadata that was serialized with the table.
    header_cols = {entry["name"]: entry for entry in meta_hdr["datatype"]}
    for col in astropy_table.columns.values():
        col_header = header_cols[col.name]
        for attr in ("description", "format", "unit", "meta"):
            if attr in col_header:
                setattr(col, attr, col_header[attr])

    if "meta" in meta_hdr:
        astropy_table.meta.update(meta_hdr["meta"])
def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Return the numpy dtype string for an arrow string/binary column.

    The length is taken from the ``lsst::arrow::len::<name>`` schema
    metadata when present, otherwise from the longest entry of
    ``numpy_column`` when that is given and non-empty, otherwise from
    ``default_length``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column to determine numpy string dtype.
    default_length : `int`, optional
        String length used when it is neither in the metadata nor
        inferable from the column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string: ``U<len>`` for arrow string columns,
        ``|S<len>`` otherwise.
    """
    metadata = {} if schema.metadata is None else schema.metadata
    length_key = f"lsst::arrow::len::{name}".encode("UTF-8")

    if length_key in metadata:
        # String/bytes length from header.
        strlen = int(metadata[length_key])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(len(row) for row in numpy_column)
    else:
        strlen = default_length

    # Unicode for arrow strings, raw bytes for everything else.
    prefix = "U" if schema.field(name).type == pa.string() else "|S"

    return f"{prefix}{strlen}"
1090def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
1091 """Append numpy string length keys to arrow metadata.
1093 All column types are handled, but the metadata is only modified for
1094 string and byte columns.
1096 Parameters
1097 ----------
1098 metadata : `dict` [`bytes`, `str`]
1099 Metadata dictionary; modified in place.
1100 name : `str`
1101 Column name.
1102 dtype : `np.dtype`
1103 Numpy dtype.
1104 """
1105 import numpy as np
1107 if dtype.type is np.str_:
1108 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize // 4)
1109 elif dtype.type is np.bytes_:
1110 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize)
def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Record the shape of a multi-dimensional numpy column in arrow metadata.

    Any dtype may be passed in; an entry is only written when the dtype
    has a sub-array shape with more than one dimension.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype.
    """
    dims = dtype.shape
    if len(dims) <= 1:
        # Scalar and one-dimensional columns get no shape entry.
        return

    key = f"lsst::arrow::shape::{name}".encode()
    metadata[key] = str(dims)
def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
    """Retrieve the shape from the metadata, if available.

    The shape is looked up under the key ``lsst::arrow::shape::{name}`` and
    is expected to be the text of a tuple of integers, e.g. ``(2, 3)``.
    When no such key exists the column is treated as one-dimensional with
    ``list_size`` elements.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `bytes`]
        Metadata dictionary.
    list_size : `int`
        Size of the list datatype.
    name : `str`
        Column name.

    Returns
    -------
    shape : `tuple` [`int`]
        Shape associated with the column.

    Raises
    ------
    RuntimeError
        Raised if metadata is found but has incorrect format; this includes
        a missing parenthesised group and non-integer elements.
    """
    encoded = f"lsst::arrow::shape::{name}".encode("UTF-8")
    if encoded not in metadata:
        return (list_size,)

    match = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8"))
    if match is None:
        raise RuntimeError("Illegal value found in metadata.")

    # Strip each element before testing for emptiness so that both the
    # trailing comma of a one-element tuple ("(5,)") and stray whitespace
    # ("(5, )") are tolerated.
    elements = (element.strip() for element in match[1].split(","))
    try:
        return tuple(int(element) for element in elements if element)
    except ValueError as err:
        # A non-integer element is a format error, which this function
        # documents as RuntimeError rather than a bare ValueError.
        raise RuntimeError("Illegal value found in metadata.") from err
def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Build a numpy dtype description from a pyarrow schema.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        A list with name, type pairs, one per schema column and in schema
        order.
    """
    metadata = {} if schema.metadata is None else schema.metadata
    string_like = (pa.string(), pa.binary())

    pairs: list[Any] = []
    for name in schema.names:
        arrow_type = schema.field(name).type

        column_dtype: Any
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists become sub-array columns; the shape comes
            # from the metadata when present, else from the list size.
            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            column_dtype = (arrow_type.value_type.to_pandas_dtype(), shape)
        elif arrow_type in string_like:
            # String and binary columns need an explicit width.
            column_dtype = _arrow_string_to_numpy_dtype(schema, name)
        else:
            column_dtype = arrow_type.to_pandas_dtype()

        pairs.append((name, column_dtype))

    return pairs
def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Translate the fields of a numpy dtype into arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        List of ``(name, arrow_type)`` pairs, one per field of ``dtype``.
        Empty if ``dtype`` has no named fields.
    """
    from math import prod

    import numpy as np

    if dtype.names is None:
        return []

    def _field_to_arrow(field_dtype: np.dtype) -> Any:
        # Scalar fields convert directly.
        if len(field_dtype.shape) == 0:
            return pa.from_numpy_dtype(field_dtype.type)
        # Sub-array fields become a list of the base type sized to the
        # product of the sub-array dimensions.
        base_dtype = cast(tuple[np.dtype, tuple[int, ...]], field_dtype.subdtype)[0]
        return pa.list_(pa.from_numpy_dtype(base_dtype.type), prod(field_dtype.shape))

    return [(name, _field_to_arrow(dtype[name])) for name in dtype.names]
def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
    """Extract equivalent table dtype from dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    dtype : `numpy.dtype`
        dtype of equivalent table.
    rowcount : `int`
        Number of rows in the table; 0 if ``numpy_dict`` is empty.

    Raises
    ------
    ValueError
        Raised if columns in numpy_dict have unequal numbers of rows.
    """
    import numpy as np

    dtype_list = []
    # None means "no column seen yet". Using 0 for that would let a
    # zero-length first column be silently overridden by a later, longer
    # column without the mismatch being reported.
    rowcount: int | None = None
    for name, col in numpy_dict.items():
        if rowcount is None:
            rowcount = len(col)
        elif len(col) != rowcount:
            raise ValueError(f"Column {name} has a different number of rows.")

        if len(col.shape) == 1:
            dtype_list.append((name, col.dtype))
        else:
            # Trailing dimensions become the sub-array shape of the field.
            dtype_list.append((name, (col.dtype, col.shape[1:])))
    dtype = np.dtype(dtype_list)

    return (dtype, rowcount if rowcount is not None else 0)
def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Turn each named column of a numpy-style table into an arrow array.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
        or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays, in the order of ``dtype.names``.
        Empty if ``dtype`` has no named fields.
    """
    import numpy as np

    converted: list[pa.Array] = []
    if dtype.names is None:
        return converted

    for name in dtype.names:
        field_dtype = dtype[name]

        values: Any
        if len(field_dtype.shape) == 0:
            values = np_style_arrays[name]
        elif rowcount > 0:
            # Flatten the sub-array column and cut it into one equal-sized
            # piece per row.
            values = np.split(np_style_arrays[name].ravel(), rowcount)
        else:
            values = []

        try:
            arrow_array = pa.array(values, type=schema.field(name).type)
        except pa.ArrowNotImplementedError as err:
            # Retry only if the failure can be explained by the data being
            # big-endian; a "=" (native) byte order counts as big-endian on
            # a big-endian machine.
            byteorder = values.dtype.byteorder
            is_big_endian = (byteorder == ">") if np.little_endian else (byteorder == "=")
            if not is_big_endian:
                # This failed for some other reason so raise the exception.
                raise err

            # Swap the bytes, relabel the dtype as little-endian, and retry.
            swapped = values.byteswap()
            swapped.dtype = swapped.dtype.newbyteorder("<")
            arrow_array = pa.array(swapped, type=schema.field(name).type)

        converted.append(arrow_array)

    return converted
def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row group
    that targets the persisted size on disk (or smaller). The exact size on
    disk depends on the compression settings and ratios; typical binary data
    tables will have around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size. Never less
        than 1, since a row group cannot hold fewer than one row even when
        a single row is estimated to be wider than ``target_size``.
    """
    bit_width = 0

    metadata = schema.metadata if schema.metadata is not None else {}

    for name in schema.names:
        t = schema.field(name).type

        if t in (pa.string(), pa.binary()):
            encoded = f"lsst::arrow::len::{name}".encode("UTF-8")

            if encoded in metadata:
                # String/bytes length from header.
                strlen = int(metadata[encoded])
            else:
                # We don't know the string width, so guess something.
                strlen = 10

            # Assuming UTF-8 encoding, and very few wide characters.
            t_width = 8 * strlen
        elif isinstance(t, pa.FixedSizeListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                t_width = t.list_size * t.value_type.bit_width
        elif t == pa.null():
            t_width = 0
        elif isinstance(t, pa.ListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                # This is a variable length list, just choose
                # something arbitrary.
                t_width = 10 * t.value_type.bit_width
        else:
            t_width = t.bit_width

        bit_width += t_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    byte_width = max(bit_width, 8) // 8

    # Clamp to one row so that a very wide schema or a tiny target never
    # produces a row group size of zero.
    return max(1, target_size // byte_width)