Coverage for python/lsst/daf/butler/formatters/parquet.py: 13%
461 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-07 11:04 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-07 11:04 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = (
31 "ParquetFormatter",
32 "arrow_to_pandas",
33 "arrow_to_astropy",
34 "arrow_to_numpy",
35 "arrow_to_numpy_dict",
36 "pandas_to_arrow",
37 "pandas_to_astropy",
38 "astropy_to_arrow",
39 "numpy_to_arrow",
40 "numpy_to_astropy",
41 "numpy_dict_to_arrow",
42 "arrow_schema_to_pandas_index",
43 "DataFrameSchema",
44 "ArrowAstropySchema",
45 "ArrowNumpySchema",
46 "compute_row_group_size",
47)
49import collections.abc
50import itertools
51import json
52import re
53from collections.abc import Iterable, Sequence
54from typing import TYPE_CHECKING, Any, cast
56import pyarrow as pa
57import pyarrow.parquet as pq
58from lsst.daf.butler import Formatter
59from lsst.utils.introspection import get_full_type_name
60from lsst.utils.iteration import ensure_iterable
62if TYPE_CHECKING:
63 import astropy.table as atable
64 import numpy as np
65 import pandas as pd
67TARGET_ROW_GROUP_BYTES = 1_000_000_000
class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        schema = pq.read_schema(self.fileDescriptor.location.path)

        # The schema metadata may be None; normalize it once so that every
        # membership test below is safe.
        metadata = schema.metadata if schema.metadata is not None else {}

        schema_names = ["ArrowSchema", "DataFrameSchema", "ArrowAstropySchema", "ArrowNumpySchema"]

        if component in ("columns", "schema") or self.fileDescriptor.readStorageClass.name in schema_names:
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            # Fall back to reading a single column and counting its rows.
            temp_table = pq.read_table(
                self.fileDescriptor.location.path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            # "columns" has been popped above; anything left is unsupported.
            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            self.fileDescriptor.location.path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Convert the supported in-memory table types to an arrow table and
        # write it as a Parquet file; a bare arrow schema is written as a
        # metadata-only file.
        import numpy as np
        from astropy.table import Table as astropyTable

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif isinstance(inMemoryDataset, pa.Schema):
            pq.write_metadata(inMemoryDataset, location.path)
            return
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)
def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Turn a pyarrow table into a pandas DataFrame.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert. Any ``pandas`` metadata stored in its schema
        is used when building the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    conversion_options = {"use_threads": False, "integer_object_nulls": True}
    dataframe = arrow_table.to_pandas(**conversion_options)
    return dataframe
def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Turn a pyarrow table into an `astropy.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert. Astropy metadata found in the schema is
        applied to the resulting ``astropy.Table``.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    # Build the table from plain numpy columns, then restore the astropy
    # metadata recorded in the arrow schema.
    columns = arrow_to_numpy_dict(arrow_table)
    result = Table(columns)
    _apply_astropy_metadata(result, arrow_table.schema)
    return result
def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Turn a pyarrow table into a structured numpy array.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names are the column names of
        the input arrow table.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    # Columns with more than one dimension become sub-array fields whose
    # shape is the per-row shape of the column.
    dtype = [
        (name, col.dtype) if len(col.shape) <= 1 else (name, (col.dtype, col.shape[1:]))
        for name, col in columns.items()
    ]

    return np.rec.fromarrays(columns.values(), dtype=dtype)
def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Turn a pyarrow table into a dict of numpy arrays.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's array. Columns containing
        nulls are returned as masked arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = {} if schema.metadata is None else schema.metadata

    result = {}

    for name in schema.names:
        arrow_type = schema.field(name).type
        column = arrow_table[name]
        is_text = arrow_type in (pa.string(), pa.binary())

        if column.null_count != 0:
            # Nulls must be replaced by a value of the right type before
            # conversion; the null positions are then recorded in the mask.
            if is_text:
                filler = ""
            else:
                filler = arrow_type.to_pandas_dtype()(0)

            values = np.ma.masked_array(
                data=column.fill_null(filler).to_numpy(),
                mask=column.is_null().to_numpy(),
            )
        else:
            # No nulls: a plain (unmasked) array is enough.
            values = column.to_numpy()

        if is_text:
            values = values.astype(_arrow_string_to_numpy_dtype(schema, name, values))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(values) > 0:
                values = np.stack(values)
            else:
                # An empty column carries no per-row arrays to stack, so
                # coerce it to the element type directly.
                values = values.astype(arrow_type.value_type.to_pandas_dtype())

            per_row_shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            values = values.reshape((len(arrow_table), *per_row_shape))

        result[name] = values

    return result
def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Turn a dict of numpy arrays into a structured numpy array.

    The conversion is a round trip through an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured array with N rows and field names taken from the dict
        keys.
    """
    as_arrow = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(as_arrow)
def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Turn a structured numpy array into a dict of numpy arrays.

    The conversion is a round trip through an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    as_arrow = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(as_arrow)
def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Turn a structured numpy array into an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. Its schema metadata records the row count
        plus string-length and multi-dimensional-shape entries for the
        fields that need them.
    """
    array_dtype = np_array.dtype
    type_list = _numpy_dtype_to_arrow_types(array_dtype)

    md = {b"lsst::arrow::rowcount": str(len(np_array))}

    for name in array_dtype.names:
        field_dtype = array_dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(array_dtype, len(np_array), np_array, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Turn a dict of numpy arrays into an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
    type_list = _numpy_dtype_to_arrow_types(dtype)

    md = {b"lsst::arrow::rowcount": str(rowcount)}

    # A dtype without named fields contributes no per-column metadata.
    field_names = dtype.names
    if field_names is not None:
        for name in field_names:
            field_dtype = dtype[name]
            _append_numpy_string_metadata(md, name, field_dtype)
            _append_numpy_multidim_metadata(md, name, field_dtype)

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, rowcount, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)
def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Turn an astropy table into an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Astropy table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. The astropy table header is stored as
        YAML in the schema metadata, and column descriptions and units are
        stored as field metadata.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    type_list = _numpy_dtype_to_arrow_types(table_dtype)

    md = {b"lsst::arrow::rowcount": str(len(astropy_table))}

    for name in table_dtype.names:
        column_dtype = table_dtype[name]
        _append_numpy_string_metadata(md, name, column_dtype)
        _append_numpy_multidim_metadata(md, name, column_dtype)

    yaml_lines = meta.get_yaml_from_table(astropy_table)
    md[b"table_meta_yaml"] = "\n".join(yaml_lines)

    # Attach description/unit metadata to each field; falsy values are
    # omitted.
    fields = []
    for name, pa_type in type_list:
        field_metadata = {}
        description = astropy_table[name].description
        if description:
            field_metadata["description"] = description
        unit = astropy_table[name].unit
        if unit:
            field_metadata["unit"] = str(unit)
        fields.append(pa.field(name, pa_type, metadata=field_metadata))

    schema = pa.schema(fields, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        table_dtype,
        len(astropy_table),
        astropy_table,
        schema,
    )

    return pa.Table.from_arrays(arrays, schema=schema)
def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Turn an astropy table into a dict of numpy arrays.

    The conversion is a round trip through an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Astropy table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    as_arrow = astropy_to_arrow(astropy_table)
    return arrow_to_numpy_dict(as_arrow)
def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column when it cannot be
        inferred from the data, i.e. the column is empty or contains only
        null values.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table, with the row count and the maximum string
        length of every string column added to the schema metadata.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if not name.startswith("__") and arrow_table[name].type == pa.string():
            # ``default`` covers both an empty column and a column whose
            # values are all null; without it ``max`` raises ValueError on
            # the empty sequence left after filtering out the nulls.
            strlen = max(
                (len(row.as_py()) for row in arrow_table[name] if row.is_valid),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table
def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Turn a pandas dataframe into an astropy table, keeping the index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe columns are a `pandas.MultiIndex`.
    """
    import pandas as pd
    from astropy.table import Table

    has_multi_index_columns = isinstance(dataframe.columns, pd.MultiIndex)
    if has_multi_index_columns:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table
def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Turn a pandas dataframe into a dict of numpy arrays.

    The conversion is a round trip through an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    as_arrow = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(as_arrow)
def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Wrap a structured numpy array in an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table, constructed with ``copy=False``.
    """
    from astropy.table import Table

    astropy_table = Table(data=np_array, copy=False)
    return astropy_table
def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only when
        the ``pandas`` schema metadata declares more than one column index
        level; otherwise a flat index of the column names that do not start
        with ``__`` is returned.
    """
    import pandas as pd

    # The schema metadata may be None; treat that the same as having no
    # pandas metadata instead of failing on the membership test.
    metadata = schema.metadata if schema.metadata is not None else {}

    indexes = []
    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]

    if len(indexes) <= 1:
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])
def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Return the column names of an arrow schema as a list.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Pyarrow schema to read the names from.

    Returns
    -------
    column_list : `list` [`str`]
        New list holding the schema's column names, in schema order.
    """
    column_list = [*schema.names]
    return column_list
class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row selection of the dataframe, which
    keeps its columns but none of its data.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean mask selects no rows.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty row list is all that is needed; the columns come from
        # the schema.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        # The zero-row dataframe that represents the schema.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)
class ArrowAstropySchema:
    """Schema wrapper for an astropy table.

    The schema is kept as a zero-row slice of the table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table whose schema should be captured.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Build an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        # Create an empty table with the right dtype, then restore any
        # astropy metadata carried by the arrow schema.
        empty = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        table = Table(data=empty)
        _apply_astropy_metadata(table, schema)

        return cls(table)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The converted pyarrow schema.
        """
        arrow_table = astropy_to_arrow(self._schema)
        return arrow_table.schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The converted dataframe schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return DataFrameSchema.from_arrow(arrow_schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Return the equivalent `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            The converted arrow numpy schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return ArrowNumpySchema.from_arrow(arrow_schema)

    @property
    def schema(self) -> atable.Table:
        # The zero-row table that represents the schema.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        mine, theirs = self._schema, other._schema

        # Equal dtypes imply that both tables have the same column names,
        # so the per-column lookups below are safe.
        if mine.dtype != theirs.dtype:
            return False

        for name in mine.columns:
            if not mine[name].unit == theirs[name].unit:
                return False
            if not mine[name].description == theirs[name].description:
                return False
            if not mine[name].format == theirs[name].format:
                return False

        return True
class ArrowNumpySchema:
    """Schema wrapper for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype that describes the schema.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Build an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            The converted arrow numpy schema.
        """
        import numpy as np

        dtype_list = _schema_to_dtype_list(schema)
        return cls(np.dtype(dtype_list))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Return the equivalent `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The converted arrow astropy schema.
        """
        import numpy as np

        # A zero-row array of this dtype is enough to derive the schema.
        empty = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The converted dataframe schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The converted pyarrow schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty).schema

    @property
    def schema(self) -> np.dtype:
        # The numpy dtype that represents the schema.
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        return True if self._dtype == other._dtype else False
897def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]:
898 """Split a string that represents a multi-index column.
900 PyArrow maps Pandas' multi-index column names (which are tuples in Python)
901 to flat strings on disk. This routine exists to reconstruct the original
902 tuple.
904 Parameters
905 ----------
906 n : `int`
907 Number of levels in the `pandas.MultiIndex` that is being
908 reconstructed.
909 names : `~collections.abc.Iterable` [`str`]
910 Strings to be split.
912 Returns
913 -------
914 column_names : `list` [`tuple` [`str`]]
915 A list of multi-index column name tuples.
916 """
917 column_names: list[Sequence[str]] = []
919 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n)))
920 for name in names:
921 m = re.search(pattern, name)
922 if m is not None:
923 column_names.append(m.groups())
925 return column_names
928def _standardize_multi_index_columns(
929 pd_index: pd.MultiIndex,
930 columns: Any,
931 stringify: bool = True,
932) -> list[str | Sequence[Any]]:
933 """Transform a dictionary/iterable index from a multi-index column list
934 into a string directly understandable by PyArrow.
936 Parameters
937 ----------
938 pd_index : `pandas.MultiIndex`
939 Pandas multi-index.
940 columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
941 Columns to standardize.
942 stringify : `bool`, optional
943 Should the column names be stringified?
945 Returns
946 -------
947 names : `list` [`str`]
948 Stringified representation of a multi-index column name.
949 """
950 index_level_names = tuple(pd_index.names)
952 names: list[str | Sequence[Any]] = []
954 if isinstance(columns, list):
955 for requested in columns:
956 if not isinstance(requested, tuple):
957 raise ValueError(
958 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
959 f"Instead got a {get_full_type_name(requested)}."
960 )
961 if stringify:
962 names.append(str(requested))
963 else:
964 names.append(requested)
965 else:
966 if not isinstance(columns, collections.abc.Mapping):
967 raise ValueError(
968 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
969 f"Instead got a {get_full_type_name(columns)}."
970 )
971 if not set(index_level_names).issuperset(columns.keys()):
972 raise ValueError(
973 f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
974 )
975 factors = [
976 ensure_iterable(columns.get(level, pd_index.levels[i]))
977 for i, level in enumerate(index_level_names)
978 ]
979 for requested in itertools.product(*factors):
980 for i, value in enumerate(requested):
981 if value not in pd_index.levels[i]:
982 raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.")
983 if stringify:
984 names.append(str(requested))
985 else:
986 names.append(requested)
988 return names
def _apply_astropy_metadata(astropy_table: atable.Table, arrow_schema: pa.Schema) -> None:
    """Copy astropy metadata from an arrow schema onto a table.

    The serialized astropy header (``table_meta_yaml``) is used when it is
    present; otherwise the per-field ``description`` and ``unit`` metadata
    of the arrow schema are applied.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to update; modified in place.
    arrow_schema : `pyarrow.Schema`
        Arrow schema carrying the metadata.
    """
    from astropy.table import meta

    metadata = arrow_schema.metadata
    if metadata is None:
        metadata = {}

    # Check if we have a special astropy metadata header yaml.
    yaml_blob = metadata.get(b"table_meta_yaml", None)
    if yaml_blob:
        yaml_lines = yaml_blob.decode("UTF8").split("\n")
        header = meta.get_header_from_yaml(yaml_lines)

        # Restore description, format, unit and meta for each column from
        # the header that was serialized with the table.
        columns_by_name = {entry["name"]: entry for entry in header["datatype"]}
        for col in astropy_table.columns.values():
            column_header = columns_by_name[col.name]
            for attr in ("description", "format", "unit", "meta"):
                if attr in column_header:
                    setattr(col, attr, column_header[attr])

        if "meta" in header:
            astropy_table.meta.update(header["meta"])
        return

    # Without astropy header data we may still have arrow field metadata.
    for name in arrow_schema.names:
        field_metadata = arrow_schema.field(name).metadata
        if field_metadata is None:
            continue

        if b"description" in field_metadata:
            description = field_metadata[b"description"].decode("UTF-8")
            if description != "":
                astropy_table[name].description = description

        if b"unit" in field_metadata:
            unit = field_metadata[b"unit"].decode("UTF-8")
            if unit != "":
                astropy_table[name].unit = unit
def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Return the numpy dtype string for an arrow string/binary column.

    The length is taken from the ``lsst::arrow::len::<name>`` schema
    metadata when available, otherwise from the longest entry of
    ``numpy_column``, otherwise from ``default_length``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column used to determine the string length.
    default_length : `int`, optional
        Length used when it is neither in the metadata nor can be inferred
        from the column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string: ``U<len>`` for arrow string columns and
        ``|S<len>`` for anything else.
    """
    length_key = f"lsst::arrow::len::{name}".encode("UTF-8")
    metadata = {} if schema.metadata is None else schema.metadata

    if length_key in metadata:
        # String/bytes length from header.
        strlen = int(schema.metadata[length_key])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(len(row) for row in numpy_column)
    else:
        strlen = default_length

    if schema.field(name).type == pa.string():
        return f"U{strlen}"
    return f"|S{strlen}"
1074def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
1075 """Append numpy string length keys to arrow metadata.
1077 All column types are handled, but the metadata is only modified for
1078 string and byte columns.
1080 Parameters
1081 ----------
1082 metadata : `dict` [`bytes`, `str`]
1083 Metadata dictionary; modified in place.
1084 name : `str`
1085 Column name.
1086 dtype : `np.dtype`
1087 Numpy dtype.
1088 """
1089 import numpy as np
1091 if dtype.type is np.str_:
1092 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize // 4)
1093 elif dtype.type is np.bytes_:
1094 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize)
1097def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
1098 """Append numpy multi-dimensional shapes to arrow metadata.
1100 All column types are handled, but the metadata is only modified for
1101 multi-dimensional columns.
1103 Parameters
1104 ----------
1105 metadata : `dict` [`bytes`, `str`]
1106 Metadata dictionary; modified in place.
1107 name : `str`
1108 Column name.
1109 dtype : `np.dtype`
1110 Numpy dtype.
1111 """
1112 if len(dtype.shape) > 1:
1113 metadata[f"lsst::arrow::shape::{name}".encode()] = str(dtype.shape)
def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
    """Retrieve the shape from the metadata, if available.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `bytes`]
        Metadata dictionary.
    list_size : `int`
        Size of the list datatype.
    name : `str`
        Column name.

    Returns
    -------
    shape : `tuple` [`int`]
        Shape associated with the column. If no shape entry exists for
        the column this is the one-dimensional shape ``(list_size,)``.

    Raises
    ------
    RuntimeError
        Raised if metadata is found but has incorrect format.
    """
    key = f"lsst::arrow::shape::{name}".encode("UTF-8")
    if key not in metadata:
        return (list_size,)

    # The stored value is expected to look like the ``str()`` of a tuple,
    # e.g. "(2, 3)" or "(5,)".
    match = re.search(r"\((.*)\)", metadata[key].decode("UTF-8"))
    if match is None:
        raise RuntimeError("Illegal value found in metadata.")

    # Strip each token so that blank pieces (from a trailing comma or
    # stray whitespace) are skipped instead of being passed to ``int``.
    tokens = (token.strip() for token in match[1].split(","))
    try:
        return tuple(int(token) for token in tokens if token)
    except ValueError as err:
        # Non-integer content is a format error; report it with the
        # documented exception type rather than leaking ValueError.
        raise RuntimeError("Illegal value found in metadata.") from err
def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Build a numpy dtype description from a pyarrow schema.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list : `list` [`tuple`]
        A list with name, type pairs, one per schema column and in schema
        order.
    """
    schema_metadata = {} if schema.metadata is None else schema.metadata
    string_like_types = (pa.string(), pa.binary())

    dtype_list: list[Any] = []
    for column in schema.names:
        arrow_type = schema.field(column).type
        entry: Any
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists become sub-array columns; the shape comes
            # from the metadata when present, else from the list size.
            shape = _multidim_shape_from_metadata(schema_metadata, arrow_type.list_size, column)
            entry = (arrow_type.value_type.to_pandas_dtype(), shape)
        elif arrow_type in string_like_types:
            # String and binary columns need a fixed-width numpy dtype.
            entry = _arrow_string_to_numpy_dtype(schema, column)
        else:
            entry = arrow_type.to_pandas_dtype()
        dtype_list.append((column, entry))

    return dtype_list
def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Translate the fields of a numpy dtype into arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        Converted list of ``(name, arrow_type)`` pairs, one per named
        field. Empty if ``dtype`` has no named fields.
    """
    from math import prod

    import numpy as np

    field_names = dtype.names
    if field_names is None:
        return []

    def _to_arrow(field_dtype: Any) -> Any:
        # Scalar fields map directly onto an arrow type.
        if len(field_dtype.shape) == 0:
            return pa.from_numpy_dtype(field_dtype.type)
        # Sub-array fields become a fixed-size list holding the
        # flattened elements of one row.
        element_dtype = cast(tuple[np.dtype, tuple[int, ...]], field_dtype.subdtype)[0]
        return pa.list_(pa.from_numpy_dtype(element_dtype.type), prod(field_dtype.shape))

    return [(field_name, _to_arrow(dtype[field_name])) for field_name in field_names]
def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
    """Extract equivalent table dtype from dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    dtype : `numpy.dtype`
        dtype of equivalent table.
    rowcount : `int`
        Number of rows in the table; 0 if ``numpy_dict`` is empty.

    Raises
    ------
    ValueError
        Raised if columns in numpy_dict have unequal numbers of rows.
    """
    import numpy as np

    dtype_list = []
    # ``None`` means "no column seen yet". Using 0 as the sentinel would
    # let an empty first column be silently overridden by a later,
    # non-empty column instead of being reported as a mismatch.
    rowcount: int | None = None
    for name, col in numpy_dict.items():
        if rowcount is None:
            rowcount = len(col)
        elif len(col) != rowcount:
            raise ValueError(f"Column {name} has a different number of rows.")
        if len(col.shape) == 1:
            dtype_list.append((name, col.dtype))
        else:
            # Trailing dimensions become the sub-array shape of the field.
            dtype_list.append((name, (col.dtype, col.shape[1:])))
    dtype = np.dtype(dtype_list)

    return (dtype, 0 if rowcount is None else rowcount)
def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
            or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays, one per named field of
        ``dtype``. Empty if ``dtype`` has no named fields.
    """
    import numpy as np

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    def _as_arrow_input(column: Any, is_multidim: bool) -> Any:
        # Scalar columns are passed through unchanged; multi-dimensional
        # columns are flattened and cut into one flat piece per row.
        if not is_multidim:
            return column
        if rowcount > 0:
            return np.split(column.ravel(), rowcount)
        return []

    for name in dtype.names:
        is_multidim = len(dtype[name].shape) > 0
        column = np_style_arrays[name]
        arrow_type = schema.field(name).type

        try:
            arrow_arrays.append(pa.array(_as_arrow_input(column, is_multidim), type=arrow_type))
        except pa.ArrowNotImplementedError:
            # Check if the column is big-endian. The source column is
            # inspected rather than the value handed to pyarrow, because
            # for multi-dimensional columns that value is a plain list
            # with no ``dtype`` attribute.
            byteorder = column.dtype.byteorder
            is_big_endian = (np.little_endian and byteorder == ">") or (
                not np.little_endian and byteorder == "="
            )
            if not is_big_endian:
                # This failed for some other reason so re-raise the
                # exception with its original traceback.
                raise
            # We need to convert the array to little-endian.
            swapped = column.byteswap()
            swapped.dtype = swapped.dtype.newbyteorder("<")
            arrow_arrays.append(pa.array(_as_arrow_input(swapped, is_multidim), type=arrow_type))

    return arrow_arrays
def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Estimate how many rows fit in a row group of a target byte size.

    Given a schema, this routine will compute the number of rows in a row group
    that targets the persisted size on disk (or smaller). The exact size on
    disk depends on the compression settings and ratios; typical binary data
    tables will have around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size.
    """
    header = {} if schema.metadata is None else schema.metadata
    null_type = pa.null()
    string_like_types = (pa.string(), pa.binary())

    row_bits = 0
    for column in schema.names:
        arrow_type = schema.field(column).type

        if arrow_type in string_like_types:
            # Use the string/bytes length from the header when recorded;
            # otherwise the width is unknown, so guess something.
            length_key = f"lsst::arrow::len::{column}".encode("UTF-8")
            strlen = int(header.get(length_key, 10))
            # Assuming UTF-8 encoding, and very few wide characters.
            row_bits += 8 * strlen
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if arrow_type.value_type != null_type:
                row_bits += arrow_type.list_size * arrow_type.value_type.bit_width
        elif arrow_type == null_type:
            # Null columns add nothing to the row width.
            continue
        elif isinstance(arrow_type, pa.ListType):
            if arrow_type.value_type != null_type:
                # This is a variable length list, just choose
                # something arbitrary.
                row_bits += 10 * arrow_type.value_type.bit_width
        else:
            row_bits += arrow_type.bit_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    row_bytes = max(row_bits, 8) // 8

    return target_size // row_bytes