Coverage for python/lsst/daf/butler/formatters/parquet.py: 13%

461 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-25 10:50 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Public names exported by ``from lsst.daf.butler.formatters.parquet import *``:
# the formatter itself, the table conversion helpers, and the schema wrappers.
__all__ = (
    "ParquetFormatter",
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    "pandas_to_arrow",
    "pandas_to_astropy",
    "astropy_to_arrow",
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
    "compute_row_group_size",
)

48 

49import collections.abc 

50import itertools 

51import json 

52import re 

53from collections.abc import Iterable, Sequence 

54from typing import TYPE_CHECKING, Any, cast 

55 

56import pyarrow as pa 

57import pyarrow.parquet as pq 

58from lsst.daf.butler import Formatter 

59from lsst.utils.introspection import get_full_type_name 

60from lsst.utils.iteration import ensure_iterable 

61 

62if TYPE_CHECKING: 

63 import astropy.table as atable 

64 import numpy as np 

65 import pandas as pd 

66 

# Target size of a Parquet row group, in bytes (1 GB).
# NOTE(review): presumably consumed by ``compute_row_group_size``, which is
# defined outside this view -- confirm.
TARGET_ROW_GROUP_BYTES = 1_000_000_000

68 

69 

class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)

        # A file written without key/value metadata reports ``None`` here.
        # Normalize once so that every membership test below is safe.
        metadata = schema.metadata if schema.metadata is not None else {}

        schema_names = ["ArrowSchema", "DataFrameSchema", "ArrowAstropySchema", "ArrowNumpySchema"]

        if component in ("columns", "schema") or self.fileDescriptor.readStorageClass.name in schema_names:
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            # Reading a single column is the cheapest way to count rows.
            temp_table = pq.read_table(
                path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # "columns" is the only supported parameter; it is popped so that
            # anything left over can be reported as unsupported.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Accepts an arrow table or schema, an astropy table, a structured
        # numpy array, a dict of numpy arrays, or a pandas DataFrame; any
        # other type raises ValueError.
        import numpy as np
        from astropy.table import Table as astropyTable

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif isinstance(inMemoryDataset, pa.Schema):
            # A bare schema is persisted as a metadata-only file.
            pq.write_metadata(inMemoryDataset, location.path)
            return
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)

186 

187 

def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Build a pandas DataFrame from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Any ``pandas`` metadata stored in its schema is
        used when constructing the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    # Single-threaded conversion; integer columns that contain nulls are
    # returned as object columns.
    dataframe = arrow_table.to_pandas(
        use_threads=False,
        integer_object_nulls=True,
    )
    return dataframe

204 

205 

def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Build an `astropy.Table` from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Astropy metadata stored in its schema (for
        example units) is applied to the resulting table.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    columns = arrow_to_numpy_dict(arrow_table)
    astropy_table = Table(columns)

    # Restore units, descriptions, etc. recorded in the arrow schema.
    _apply_astropy_metadata(astropy_table, arrow_table.schema)

    return astropy_table

228 

229 

def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Build a structured numpy array from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names match the column names
        of the input table.
    """
    import numpy as np

    numpy_dict = arrow_to_numpy_dict(arrow_table)

    # One-dimensional columns map to scalar fields; columns with extra
    # dimensions become sub-array fields of the per-row shape.
    dtype = [
        (name, col.dtype) if len(col.shape) <= 1 else (name, (col.dtype, col.shape[1:]))
        for name, col in numpy_dict.items()
    ]

    return np.rec.fromarrays(numpy_dict.values(), dtype=dtype)

258 

259 

def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data. Columns that contain
        nulls are returned as masked arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = schema.metadata if schema.metadata is not None else {}
    text_types = (pa.string(), pa.binary())

    numpy_dict = {}

    for name in schema.names:
        arrow_type = schema.field(name).type
        column = arrow_table[name]
        is_text = arrow_type in text_types

        if column.null_count == 0:
            # Regular non-masked column
            values = column.to_numpy()
        else:
            # Nulls must be replaced by a value of the right type before
            # arrow can convert the column; the null positions are then
            # re-applied as the mask of a masked array.
            dummy = "" if is_text else arrow_type.to_pandas_dtype()(0)

            values = np.ma.masked_array(
                data=column.fill_null(dummy).to_numpy(),
                mask=column.is_null().to_numpy(),
            )

        if is_text:
            values = values.astype(_arrow_string_to_numpy_dtype(schema, name, values))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(values) > 0:
                values = np.stack(values)
            else:
                # An empty column has to be coerced to the value type.
                values = values.astype(arrow_type.value_type.to_pandas_dtype())

            # Restore the per-row shape of a multi-dimensional column.
            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            values = values.reshape((len(arrow_table), *shape))

        numpy_dict[name] = values

    return numpy_dict

316 

317 

def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Build a structured numpy array from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured array with N rows and field names taken from the dict
        keys.
    """
    # Round-trip through arrow so the result matches the other converters.
    arrow_table = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(arrow_table)

332 

333 

def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.
    """
    # Round-trip through arrow so the result matches the other converters.
    arrow_table = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(arrow_table)

348 

349 

def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Build an arrow table from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. Its schema metadata records the row
        count under ``lsst::arrow::rowcount``.
    """
    dtype = np_array.dtype
    type_list = _numpy_dtype_to_arrow_types(dtype)
    n_rows = len(np_array)

    md = {b"lsst::arrow::rowcount": str(n_rows)}

    # Each helper inspects the field dtype and appends keys where needed.
    for name in dtype.names:
        field_dtype = dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, n_rows, np_array, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

384 

385 

def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. Its schema metadata records the row
        count under ``lsst::arrow::rowcount``.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
    type_list = _numpy_dtype_to_arrow_types(dtype)

    md = {b"lsst::arrow::rowcount": str(rowcount)}

    # ``names`` is None when the derived dtype has no fields.
    field_names = dtype.names
    if field_names is not None:
        for name in field_names:
            field_dtype = dtype[name]
            _append_numpy_string_metadata(md, name, field_dtype)
            _append_numpy_multidim_metadata(md, name, field_dtype)

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, rowcount, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

426 

427 

def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. The astropy YAML header is stored in
        the schema metadata, and each column's description and unit are
        stored in the metadata of its field.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    type_list = _numpy_dtype_to_arrow_types(table_dtype)
    n_rows = len(astropy_table)

    md = {b"lsst::arrow::rowcount": str(n_rows)}

    for name in table_dtype.names:
        field_dtype = table_dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    # Serialize the astropy header as newline-joined YAML.
    md[b"table_meta_yaml"] = "\n".join(meta.get_yaml_from_table(astropy_table))

    # Convert type list to fields with metadata.
    fields = []
    for name, pa_type in type_list:
        column = astropy_table[name]
        field_metadata = {}

        description = column.description
        if description:
            field_metadata["description"] = description

        unit = column.unit
        if unit:
            field_metadata["unit"] = str(unit)

        fields.append(pa.field(name, pa_type, metadata=field_metadata))

    schema = pa.schema(fields, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, astropy_table, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

484 

485 

def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.
    """
    # Round-trip through arrow so the result matches the other converters.
    arrow_table = astropy_to_arrow(astropy_table)
    return arrow_to_numpy_dict(arrow_table)

500 

501 

def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column when it cannot be
        inferred from the data, i.e. the column is empty or holds only
        nulls.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table. Its schema metadata records the row count
        under ``lsst::arrow::rowcount`` and the maximum string length of
        each string column under ``lsst::arrow::len::<name>``.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata
    if md is None:
        md = {}

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        column = arrow_table[name]
        # Names starting with "__" are skipped, as elsewhere in this module.
        if not name.startswith("__") and column.type == pa.string():
            # ``default`` covers both an empty column and a column whose
            # values are all null; a bare max() would raise ValueError on
            # the latter.
            strlen = max(
                (len(row.as_py()) for row in column if row.is_valid),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table

538 

539 

def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Build an astropy table from a pandas dataframe, keeping its index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe columns are a `pandas.MultiIndex`.
    """
    import pandas as pd
    from astropy.table import Table

    has_multi_index_columns = isinstance(dataframe.columns, pd.MultiIndex)
    if has_multi_index_columns:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table

560 

561 

def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pandas dataframe.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.
    """
    # Round-trip through arrow so the result matches the other converters.
    arrow_table = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(arrow_table)

576 

577 

def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Wrap a structured numpy array in an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table, constructed with ``copy=False``.
    """
    from astropy.table import Table

    astropy_table = Table(data=np_array, copy=False)
    return astropy_table

594 

595 

def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema. It may carry no metadata at all, in which
        case a flat index of its column names is returned.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only when
        the ``pandas`` schema metadata declares more than one column index
        level; otherwise a flat `pandas.Index` of the column names that do
        not start with ``__``.
    """
    import pandas as pd

    # ``schema.metadata`` is None for schemas without key/value metadata;
    # normalize so that the membership test below cannot fail.
    metadata = schema.metadata if schema.metadata is not None else {}

    indexes = []
    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]

    if len(indexes) <= 1:
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    # Multi-index column names are stored as stringified tuples; split them
    # back into their per-level components.
    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])

623 

624 

def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Return the column names of an arrow schema as a list.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        New list holding every name in ``schema.names``, in order.
    """
    column_list = [*schema.names]
    return column_list

639 

640 

class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is held as a zero-row copy of the dataframe.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean mask selects zero rows while keeping the
        # columns and their dtypes.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty row list with an explicit schema yields a zero-row table.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        """The zero-row dataframe that represents this schema
        (`pandas.DataFrame`).
        """
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)

715 

716 

class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    The schema is held as a zero-row slice of the table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        # Only the structure is of interest, so drop every row.
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Create an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        # Build a zero-row table of the right dtype, then attach the
        # astropy metadata recorded in the arrow schema.
        empty = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        astropy_table = Table(data=empty)

        _apply_astropy_metadata(astropy_table, schema)

        return cls(astropy_table)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = astropy_to_arrow(self._schema)
        return arrow_table.schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return DataFrameSchema.from_arrow(arrow_schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return ArrowNumpySchema.from_arrow(arrow_schema)

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        mine, theirs = self._schema, other._schema

        # Equal dtypes imply that the two tables have the same column
        # names.
        if mine.dtype != theirs.dtype:
            return False

        # The per-column attributes must agree as well.
        for name in mine.columns:
            left, right = mine[name], theirs[name]
            for attr in ("unit", "description", "format"):
                if not getattr(left, attr) == getattr(right, attr):
                    return False

        return True

810 

811 

class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype to convert.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Create an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        dtype_list = _schema_to_dtype_list(schema)
        return cls(np.dtype(dtype_list))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np

        # A zero-row array carries the dtype through the arrow converter.
        empty = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to a `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        # Two schemas are equal exactly when their dtypes are.
        return True if self._dtype == other._dtype else False

895 

896 

897def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]: 

898 """Split a string that represents a multi-index column. 

899 

900 PyArrow maps Pandas' multi-index column names (which are tuples in Python) 

901 to flat strings on disk. This routine exists to reconstruct the original 

902 tuple. 

903 

904 Parameters 

905 ---------- 

906 n : `int` 

907 Number of levels in the `pandas.MultiIndex` that is being 

908 reconstructed. 

909 names : `~collections.abc.Iterable` [`str`] 

910 Strings to be split. 

911 

912 Returns 

913 ------- 

914 column_names : `list` [`tuple` [`str`]] 

915 A list of multi-index column name tuples. 

916 """ 

917 column_names: list[Sequence[str]] = [] 

918 

919 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n))) 

920 for name in names: 

921 m = re.search(pattern, name) 

922 if m is not None: 

923 column_names.append(m.groups()) 

924 

925 return column_names 

926 

927 

928def _standardize_multi_index_columns( 

929 pd_index: pd.MultiIndex, 

930 columns: Any, 

931 stringify: bool = True, 

932) -> list[str | Sequence[Any]]: 

933 """Transform a dictionary/iterable index from a multi-index column list 

934 into a string directly understandable by PyArrow. 

935 

936 Parameters 

937 ---------- 

938 pd_index : `pandas.MultiIndex` 

939 Pandas multi-index. 

940 columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]] 

941 Columns to standardize. 

942 stringify : `bool`, optional 

943 Should the column names be stringified? 

944 

945 Returns 

946 ------- 

947 names : `list` [`str`] 

948 Stringified representation of a multi-index column name. 

949 """ 

950 index_level_names = tuple(pd_index.names) 

951 

952 names: list[str | Sequence[Any]] = [] 

953 

954 if isinstance(columns, list): 

955 for requested in columns: 

956 if not isinstance(requested, tuple): 

957 raise ValueError( 

958 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. " 

959 f"Instead got a {get_full_type_name(requested)}." 

960 ) 

961 if stringify: 

962 names.append(str(requested)) 

963 else: 

964 names.append(requested) 

965 else: 

966 if not isinstance(columns, collections.abc.Mapping): 

967 raise ValueError( 

968 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. " 

969 f"Instead got a {get_full_type_name(columns)}." 

970 ) 

971 if not set(index_level_names).issuperset(columns.keys()): 

972 raise ValueError( 

973 f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}." 

974 ) 

975 factors = [ 

976 ensure_iterable(columns.get(level, pd_index.levels[i])) 

977 for i, level in enumerate(index_level_names) 

978 ] 

979 for requested in itertools.product(*factors): 

980 for i, value in enumerate(requested): 

981 if value not in pd_index.levels[i]: 

982 raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.") 

983 if stringify: 

984 names.append(str(requested)) 

985 else: 

986 names.append(requested) 

987 

988 return names 

989 

990 

def _apply_astropy_metadata(astropy_table: atable.Table, arrow_schema: pa.Schema) -> None:
    """Copy astropy metadata from an arrow schema onto an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to receive the metadata; modified in place.
    arrow_schema : `pyarrow.Schema`
        Arrow schema with metadata.
    """
    from astropy.table import meta

    metadata = arrow_schema.metadata if arrow_schema.metadata is not None else {}

    # Check if we have a special astropy metadata header yaml.
    meta_yaml = metadata.get(b"table_meta_yaml", None)

    if not meta_yaml:
        # Without astropy header data the arrow field metadata, when
        # present, supplies descriptions and units.
        for name in arrow_schema.names:
            field_metadata = arrow_schema.field(name).metadata
            if field_metadata is None:
                continue
            if b"description" in field_metadata:
                description = field_metadata[b"description"].decode("UTF-8")
                if description != "":
                    astropy_table[name].description = description
            if b"unit" in field_metadata:
                unit = field_metadata[b"unit"].decode("UTF-8")
                if unit != "":
                    astropy_table[name].unit = unit
        return

    meta_hdr = meta.get_header_from_yaml(meta_yaml.decode("UTF8").split("\n"))

    # Set description, format, unit, meta from the column metadata that
    # was serialized with the table.
    header_cols = {entry["name"]: entry for entry in meta_hdr["datatype"]}
    for col in astropy_table.columns.values():
        col_header = header_cols[col.name]
        for attr in ("description", "format", "unit", "meta"):
            if attr in col_header:
                setattr(col, attr, col_header[attr])

    if "meta" in meta_hdr:
        astropy_table.meta.update(meta_hdr["meta"])

1035 

1036 

def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Build the numpy dtype string for an arrow string or binary column.

    The length is looked up, in order of preference, from the
    ``lsst::arrow::len::<name>`` schema metadata entry, from the longest
    element of ``numpy_column`` (when given and non-empty), and finally
    from ``default_length``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column used to infer the string length when the schema metadata
        does not record it.
    default_length : `int`, optional
        Length to use when it is neither in the metadata nor inferable
        from the column.

    Returns
    -------
    dtype_str : `str`
        ``U<len>`` for arrow string columns, ``|S<len>`` otherwise.
    """
    schema_metadata = schema.metadata if schema.metadata is not None else {}
    length_key = f"lsst::arrow::len::{name}".encode("UTF-8")

    if length_key in schema_metadata:
        # String/bytes length recorded in the header.
        strlen = int(schema_metadata[length_key])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(map(len, numpy_column))
    else:
        strlen = default_length

    if schema.field(name).type == pa.string():
        return f"U{strlen}"
    return f"|S{strlen}"

1072 

1073 

1074def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1075 """Append numpy string length keys to arrow metadata. 

1076 

1077 All column types are handled, but the metadata is only modified for 

1078 string and byte columns. 

1079 

1080 Parameters 

1081 ---------- 

1082 metadata : `dict` [`bytes`, `str`] 

1083 Metadata dictionary; modified in place. 

1084 name : `str` 

1085 Column name. 

1086 dtype : `np.dtype` 

1087 Numpy dtype. 

1088 """ 

1089 import numpy as np 

1090 

1091 if dtype.type is np.str_: 

1092 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize // 4) 

1093 elif dtype.type is np.bytes_: 

1094 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize) 

1095 

1096 

1097def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1098 """Append numpy multi-dimensional shapes to arrow metadata. 

1099 

1100 All column types are handled, but the metadata is only modified for 

1101 multi-dimensional columns. 

1102 

1103 Parameters 

1104 ---------- 

1105 metadata : `dict` [`bytes`, `str`] 

1106 Metadata dictionary; modified in place. 

1107 name : `str` 

1108 Column name. 

1109 dtype : `np.dtype` 

1110 Numpy dtype. 

1111 """ 

1112 if len(dtype.shape) > 1: 

1113 metadata[f"lsst::arrow::shape::{name}".encode()] = str(dtype.shape) 

1114 

1115 

1116def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]: 

1117 """Retrieve the shape from the metadata, if available. 

1118 

1119 Parameters 

1120 ---------- 

1121 metadata : `dict` [`bytes`, `bytes`] 

1122 Metadata dictionary. 

1123 list_size : `int` 

1124 Size of the list datatype. 

1125 name : `str` 

1126 Column name. 

1127 

1128 Returns 

1129 ------- 

1130 shape : `tuple` [`int`] 

1131 Shape associated with the column. 

1132 

1133 Raises 

1134 ------ 

1135 RuntimeError 

1136 Raised if metadata is found but has incorrect format. 

1137 """ 

1138 md_name = f"lsst::arrow::shape::{name}" 

1139 if (encoded := md_name.encode("UTF-8")) in metadata: 

1140 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8")) 

1141 if groups is None: 

1142 raise RuntimeError("Illegal value found in metadata.") 

1143 shape = tuple(int(x) for x in groups[1].split(",") if x != "") 

1144 else: 

1145 shape = (list_size,) 

1146 

1147 return shape 

1148 

1149 

def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Translate a pyarrow schema into a numpy dtype description.

    Fixed-size list columns become ``(value dtype, shape)`` entries, string
    and binary columns become numpy string dtype strings, and every other
    column uses the dtype reported by ``to_pandas_dtype()``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        A list with name, type pairs.
    """
    metadata = schema.metadata if schema.metadata is not None else {}
    string_like = (pa.string(), pa.binary())

    dtype_list: list[Any] = []
    for name in schema.names:
        arrow_type = schema.field(name).type
        entry: Any
        if isinstance(arrow_type, pa.FixedSizeListType):
            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            entry = (arrow_type.value_type.to_pandas_dtype(), shape)
        elif arrow_type in string_like:
            entry = _arrow_string_to_numpy_dtype(schema, name)
        else:
            entry = arrow_type.to_pandas_dtype()
        dtype_list.append((name, entry))

    return dtype_list

1177 

1178 

def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Translate a structured numpy dtype into arrow field types.

    Fields with a sub-array shape are mapped to fixed-size arrow lists
    whose size is the product of the shape; all other fields are mapped
    directly with ``pyarrow.from_numpy_dtype``.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        List of ``(name, arrow type)`` pairs; empty when ``dtype`` has no
        named fields.
    """
    from math import prod

    import numpy as np

    if dtype.names is None:
        return []

    type_list: list[Any] = []
    for name in dtype.names:
        field_dtype = dtype[name]
        arrow_type: Any
        if len(field_dtype.shape) == 0:
            arrow_type = pa.from_numpy_dtype(field_dtype.type)
        else:
            # A field with a shape always has a subdtype; the cast only
            # informs the type checker of that.
            base_dtype, _ = cast(tuple[np.dtype, tuple[int, ...]], field_dtype.subdtype)
            arrow_type = pa.list_(pa.from_numpy_dtype(base_dtype.type), prod(field_dtype.shape))
        type_list.append((name, arrow_type))

    return type_list

1213 

1214 

1215def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]: 

1216 """Extract equivalent table dtype from dict of numpy arrays. 

1217 

1218 Parameters 

1219 ---------- 

1220 numpy_dict : `dict` [`str`, `numpy.ndarray`] 

1221 Dict with keys as the column names, values as the arrays. 

1222 

1223 Returns 

1224 ------- 

1225 dtype : `numpy.dtype` 

1226 dtype of equivalent table. 

1227 rowcount : `int` 

1228 Number of rows in the table. 

1229 

1230 Raises 

1231 ------ 

1232 ValueError if columns in numpy_dict have unequal numbers of rows. 

1233 """ 

1234 import numpy as np 

1235 

1236 dtype_list = [] 

1237 rowcount = 0 

1238 for name, col in numpy_dict.items(): 

1239 if rowcount == 0: 

1240 rowcount = len(col) 

1241 if len(col) != rowcount: 

1242 raise ValueError(f"Column {name} has a different number of rows.") 

1243 if len(col.shape) == 1: 

1244 dtype_list.append((name, col.dtype)) 

1245 else: 

1246 dtype_list.append((name, (col.dtype, col.shape[1:]))) 

1247 dtype = np.dtype(dtype_list) 

1248 

1249 return (dtype, rowcount) 

1250 

1251 

def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style columns into pyarrow arrays.

    One arrow array is produced per named field of ``dtype``, using the
    arrow type of the same-named field in ``schema``.  Columns whose field
    dtype has a shape are flattened and split into one chunk per row.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray` \
            or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays; empty when ``dtype`` has no
        named fields.
    """
    import numpy as np

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    for name in dtype.names:
        val: Any
        if len(dtype[name].shape) == 0:
            val = np_style_arrays[name]
        elif rowcount > 0:
            # One flattened chunk per row.
            val = np.split(np_style_arrays[name].ravel(), rowcount)
        else:
            val = []

        arrow_type = schema.field(name).type
        try:
            converted = pa.array(val, type=arrow_type)
        except pa.ArrowNotImplementedError as err:
            # NOTE(review): for columns with a shape ``val`` is a list,
            # which has no ``dtype``; confirm whether this handler can be
            # reached for them.
            byteorder = val.dtype.byteorder
            if np.little_endian:
                is_big_endian = byteorder == ">"
            else:
                is_big_endian = byteorder == "="
            if not is_big_endian:
                # This failed for some other reason so raise the exception.
                raise err
            # Swap the bytes and relabel the dtype as little-endian.
            swapped = val.byteswap()
            swapped.dtype = swapped.dtype.newbyteorder("<")
            converted = pa.array(swapped, type=arrow_type)
        arrow_arrays.append(converted)

    return arrow_arrays

1310 

1311 

def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row group
    that targets the persisted size on disk (or smaller).  The exact size on
    disk depends on the compression settings and ratios; typical binary data
    tables will have around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size.  Always at
        least 1, even when a single row is estimated to exceed
        ``target_size``.
    """
    bit_width = 0

    metadata = schema.metadata if schema.metadata is not None else {}

    for name in schema.names:
        t = schema.field(name).type

        if t in (pa.string(), pa.binary()):
            md_name = f"lsst::arrow::len::{name}"

            if (encoded := md_name.encode("UTF-8")) in metadata:
                # String/bytes length from header.
                strlen = int(metadata[encoded])
            else:
                # We don't know the string width, so guess something.
                strlen = 10

            # Assuming UTF-8 encoding, and very few wide characters.
            t_width = 8 * strlen
        elif isinstance(t, pa.FixedSizeListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                t_width = t.list_size * t.value_type.bit_width
        elif t == pa.null():
            t_width = 0
        elif isinstance(t, pa.ListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                # This is a variable length list, just choose
                # something arbitrary.
                t_width = 10 * t.value_type.bit_width
        else:
            t_width = t.bit_width

        bit_width += t_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    byte_width = max(bit_width, 8) // 8

    # A row group must hold at least one row; without the clamp a row wider
    # than the target would give a size of 0.
    return max(target_size // byte_width, 1)