# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = (
    "ParquetFormatter",
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    "pandas_to_arrow",
    "pandas_to_astropy",
    "astropy_to_arrow",
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
    "compute_row_group_size",
)

import collections.abc
import itertools
import json
import re
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, cast

import pyarrow as pa
import pyarrow.parquet as pq
from lsst.daf.butler import Formatter
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

if TYPE_CHECKING:
    import astropy.table as atable
    import numpy as np
    import pandas as pd

TARGET_ROW_GROUP_BYTES = 1_000_000_000


class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        schema = pq.read_schema(self.fileDescriptor.location.path)

        schema_names = ["ArrowSchema", "DataFrameSchema", "ArrowAstropySchema", "ArrowNumpySchema"]

        if component in ("columns", "schema") or self.fileDescriptor.readStorageClass.name in schema_names:
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in schema.metadata:
                return int(schema.metadata[b"lsst::arrow::rowcount"])

            temp_table = pq.read_table(
                self.fileDescriptor.location.path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in schema.metadata:
                    md = json.loads(schema.metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        metadata = schema.metadata if schema.metadata is not None else {}
        arrow_table = pq.read_table(
            self.fileDescriptor.location.path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        import numpy as np
        from astropy.table import Table as astropyTable

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif isinstance(inMemoryDataset, pa.Schema):
            pq.write_metadata(inMemoryDataset, location.path)
            return
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)


def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Convert a pyarrow table to a pandas DataFrame.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table to convert. If the table has ``pandas`` metadata
        in the schema it will be used in the construction of the
        ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        Converted pandas dataframe.
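
    Examples
    --------
    A minimal round-trip sketch (illustrative only, assuming ``pyarrow``
    and ``pandas`` are importable):

    >>> import pyarrow as pa
    >>> dataframe = arrow_to_pandas(pa.table({"a": [1, 2, 3]}))
    >>> len(dataframe)
    3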

    """
    return arrow_table.to_pandas(use_threads=False, integer_object_nulls=True)


def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Convert a pyarrow table to an `astropy.table.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table to convert. If the table has astropy unit
        metadata in the schema it will be used in the construction
        of the ``astropy.table.Table``.

    Returns
    -------
    table : `astropy.table.Table`
        Converted astropy table.
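
    Examples
    --------
    A minimal sketch (illustrative only, assuming ``pyarrow`` and
    ``astropy`` are importable):

    >>> import pyarrow as pa
    >>> table = arrow_to_astropy(pa.table({"flux": [1.0, 2.0]}))
    >>> table.colnames
    ['flux']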

    """
    from astropy.table import Table

    astropy_table = Table(arrow_to_numpy_dict(arrow_table))

    _apply_astropy_metadata(astropy_table, arrow_table.schema)

    return astropy_table


def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Convert a pyarrow table to a structured numpy array.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Numpy array table with N rows and the same column names
        as the input arrow table.
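
    Examples
    --------
    A minimal sketch (illustrative only):

    >>> import pyarrow as pa
    >>> array = arrow_to_numpy(pa.table({"a": [1, 2]}))
    >>> array["a"].tolist()
    [1, 2]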

    """
    import numpy as np

    numpy_dict = arrow_to_numpy_dict(arrow_table)

    dtype = []
    for name, col in numpy_dict.items():
        if len(shape := numpy_dict[name].shape) <= 1:
            dtype.append((name, col.dtype))
        else:
            dtype.append((name, (col.dtype, shape[1:])))

    array = np.rec.fromarrays(numpy_dict.values(), dtype=dtype)

    return array


def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Convert a pyarrow table to a dict of numpy arrays.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
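
    Examples
    --------
    A sketch of the null handling (illustrative only); an integer column
    containing nulls comes back as a masked array:

    >>> import pyarrow as pa
    >>> numpy_dict = arrow_to_numpy_dict(pa.table({"a": [1, None, 3]}))
    >>> bool(numpy_dict["a"].mask[1])
    True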

    """
    import numpy as np

    schema = arrow_table.schema
    metadata = schema.metadata if schema.metadata is not None else {}

    numpy_dict = {}

    for name in schema.names:
        t = schema.field(name).type

        if arrow_table[name].null_count == 0:
            # Regular non-masked column
            col = arrow_table[name].to_numpy()
        else:
            # For a masked column, we need to ask arrow to fill the null
            # values with an appropriately typed value before conversion.
            # Then we apply the mask to get a masked array of the correct type.
            use_masked = True
            null_value: Any
            match t:
                case t if t in (pa.float64(), pa.float32(), pa.float16()):
                    # When filling with nans we do not need to use
                    # the masked array.
                    null_value = np.nan
                    use_masked = False
                case t if t in (pa.int64(), pa.int32(), pa.int16(), pa.int8()):
                    null_value = -1
                case t if t in (pa.bool_(),):
                    null_value = True
                case t if t in (pa.string(), pa.binary()):
                    null_value = ""
                case _:
                    # This is the fallback for unsigned ints in particular.
                    null_value = 0

            if use_masked:
                col = np.ma.masked_array(
                    data=arrow_table[name].fill_null(null_value).to_numpy(),
                    mask=arrow_table[name].is_null().to_numpy(),
                    fill_value=null_value,
                )
            else:
                col = arrow_table[name].fill_null(null_value).to_numpy()

        if t in (pa.string(), pa.binary()):
            col = col.astype(_arrow_string_to_numpy_dtype(schema, name, col))
        elif isinstance(t, pa.FixedSizeListType):
            if len(col) > 0:
                col = np.stack(col)
            else:
                # This is an empty column, and needs to be coerced to type.
                col = col.astype(t.value_type.to_pandas_dtype())

            shape = _multidim_shape_from_metadata(metadata, t.list_size, name)
            col = col.reshape((len(arrow_table), *shape))

        numpy_dict[name] = col

    return numpy_dict


def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Convert a dict of numpy arrays to a structured numpy array.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Numpy array table with N rows and column names from the dict keys.
    """
    return arrow_to_numpy(numpy_dict_to_arrow(numpy_dict))


def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Convert a structured numpy array to a dict of numpy arrays.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(numpy_to_arrow(np_array))


def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Convert a numpy array table to an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
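
    Examples
    --------
    A minimal sketch (illustrative only, assuming ``numpy`` is importable):

    >>> import numpy as np
    >>> np_array = np.zeros(4, dtype=[("x", "f8"), ("y", "i4")])
    >>> numpy_to_arrow(np_array).num_rows
    4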

    """
    type_list = _numpy_dtype_to_arrow_types(np_array.dtype)

    md = {}
    md[b"lsst::arrow::rowcount"] = str(len(np_array))

    for name in np_array.dtype.names:
        _append_numpy_string_metadata(md, name, np_array.dtype[name])
        _append_numpy_multidim_metadata(md, name, np_array.dtype[name])

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        np_array.dtype,
        len(np_array),
        np_array,
        schema,
    )

    arrow_table = pa.Table.from_arrays(arrays, schema=schema)

    return arrow_table


def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Convert a dict of numpy arrays to an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
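
    Examples
    --------
    A minimal sketch (illustrative only):

    >>> import numpy as np
    >>> table = numpy_dict_to_arrow({"x": np.arange(3), "y": np.ones(3)})
    >>> table.column_names
    ['x', 'y']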

    """
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
    type_list = _numpy_dtype_to_arrow_types(dtype)

    md = {}
    md[b"lsst::arrow::rowcount"] = str(rowcount)

    if dtype.names is not None:
        for name in dtype.names:
            _append_numpy_string_metadata(md, name, dtype[name])
            _append_numpy_multidim_metadata(md, name, dtype[name])

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        dtype,
        rowcount,
        numpy_dict,
        schema,
    )

    arrow_table = pa.Table.from_arrays(arrays, schema=schema)

    return arrow_table


def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Convert an astropy table to an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
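
    Examples
    --------
    A minimal sketch (illustrative only, assuming ``astropy`` is
    importable):

    >>> from astropy.table import Table
    >>> astropy_table = Table({"flux": [1.0, 2.0]})
    >>> astropy_to_arrow(astropy_table).num_rows
    2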

    """
    from astropy.table import meta

    type_list = _numpy_dtype_to_arrow_types(astropy_table.dtype)

    md = {}
    md[b"lsst::arrow::rowcount"] = str(len(astropy_table))

    for name in astropy_table.dtype.names:
        _append_numpy_string_metadata(md, name, astropy_table.dtype[name])
        _append_numpy_multidim_metadata(md, name, astropy_table.dtype[name])

    meta_yaml = meta.get_yaml_from_table(astropy_table)
    meta_yaml_str = "\n".join(meta_yaml)
    md[b"table_meta_yaml"] = meta_yaml_str

    # Convert type list to fields with metadata.
    fields = []
    for name, pa_type in type_list:
        field_metadata = {}
        if description := astropy_table[name].description:
            field_metadata["description"] = description
        if unit := astropy_table[name].unit:
            field_metadata["unit"] = str(unit)
        fields.append(
            pa.field(
                name,
                pa_type,
                metadata=field_metadata,
            )
        )

    schema = pa.schema(fields, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        astropy_table.dtype,
        len(astropy_table),
        astropy_table,
        schema,
    )

    arrow_table = pa.Table.from_arrays(arrays, schema=schema)

    return arrow_table


def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Convert an astropy table to a dict of numpy arrays.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(astropy_to_arrow(astropy_table))


def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        Default string length to record in the metadata when it cannot
        be inferred from the column (e.g., when the column is empty).

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
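
    Examples
    --------
    A minimal sketch (illustrative only); the string-length metadata
    written by this function is shown as well:

    >>> import pandas as pd
    >>> arrow_table = pandas_to_arrow(pd.DataFrame({"name": ["a", "bc"]}))
    >>> arrow_table.schema.metadata[b"lsst::arrow::len::name"]
    b'2'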

    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if not name.startswith("__") and arrow_table[name].type == pa.string():
            if len(arrow_table[name]) > 0:
                strlen = max(len(row.as_py()) for row in arrow_table[name] if row.is_valid)
            else:
                strlen = default_length
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table


def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Convert a pandas dataframe to an astropy table, preserving indexes.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.
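
    Examples
    --------
    A minimal sketch (illustrative only, assuming ``pandas`` and
    ``astropy`` are importable):

    >>> import pandas as pd
    >>> astropy_table = pandas_to_astropy(pd.DataFrame({"a": [1, 2]}))
    >>> len(astropy_table)
    2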

    """
    import pandas as pd
    from astropy.table import Table

    if isinstance(dataframe.columns, pd.MultiIndex):
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    return Table.from_pandas(dataframe, index=True)


def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Convert a pandas dataframe to a dict of numpy arrays.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(pandas_to_arrow(dataframe))


def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Convert a structured numpy array to an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.
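
    Examples
    --------
    A minimal sketch (illustrative only):

    >>> import numpy as np
    >>> astropy_table = numpy_to_astropy(np.zeros(3, dtype=[("a", "f8")]))
    >>> len(astropy_table)
    3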

    """
    from astropy.table import Table

    return Table(data=np_array, copy=False)


def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index.
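
    Examples
    --------
    A minimal sketch (illustrative only); a schema without pandas metadata
    yields a flat index of column names:

    >>> import pyarrow as pa
    >>> list(arrow_schema_to_pandas_index(pa.schema([("a", pa.int64())])))
    ['a']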

    """
    import pandas as pd

    if schema.metadata is not None and b"pandas" in schema.metadata:
        md = json.loads(schema.metadata[b"pandas"])
        indexes = md["column_indexes"]
        len_indexes = len(indexes)
    else:
        len_indexes = 0

    if len_indexes <= 1:
        return pd.Index(name for name in schema.names if not name.startswith("__"))
    else:
        raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
        return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])


def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Convert an arrow schema to a list of string column names.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        Converted list of column names.
    """
    return list(schema.names)


class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        empty_table = pa.Table.from_pylist([] * len(schema.names), schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)


class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Convert an arrow schema into an `ArrowAstropySchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        dtype = _schema_to_dtype_list(schema)

        data = np.zeros(0, dtype=dtype)
        astropy_table = Table(data=data)

        _apply_astropy_metadata(astropy_table, schema)

        return cls(astropy_table)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        return astropy_to_arrow(self._schema).schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        return DataFrameSchema.from_arrow(astropy_to_arrow(self._schema).schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(astropy_to_arrow(self._schema).schema)

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # If this comparison passes then the two tables have the
        # same column names.
        if self._schema.dtype != other._schema.dtype:
            return False

        for name in self._schema.columns:
            if not self._schema[name].unit == other._schema[name].unit:
                return False
            if not self._schema[name].description == other._schema[name].description:
                return False
            if not self._schema[name].format == other._schema[name].format:
                return False

        return True


class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype to convert.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Convert an arrow schema into an `ArrowNumpySchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        dtype = _schema_to_dtype_list(schema)

        return cls(np.dtype(dtype))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np

        return ArrowAstropySchema.from_arrow(numpy_to_arrow(np.zeros(0, dtype=self._dtype)).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        import numpy as np

        return DataFrameSchema.from_arrow(numpy_to_arrow(np.zeros(0, dtype=self._dtype)).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to a `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        return numpy_to_arrow(np.zeros(0, dtype=self._dtype)).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        if not self._dtype == other._dtype:
            return False

        return True


def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]:
    """Split a string that represents a multi-index column.

    PyArrow maps Pandas' multi-index column names (which are tuples in Python)
    to flat strings on disk. This routine exists to reconstruct the original
    tuple.

    Parameters
    ----------
    n : `int`
        Number of levels in the `pandas.MultiIndex` that is being
        reconstructed.
    names : `~collections.abc.Iterable` [`str`]
        Strings to be split.

    Returns
    -------
    column_names : `list` [`tuple` [`str`]]
        A list of multi-index column name tuples.
    """
    column_names: list[Sequence[str]] = []

    pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n)))
    for name in names:
        m = re.search(pattern, name)
        if m is not None:
            column_names.append(m.groups())

    return column_names


def _standardize_multi_index_columns(
    pd_index: pd.MultiIndex,
    columns: Any,
    stringify: bool = True,
) -> list[str | Sequence[Any]]:
    """Transform a dictionary or iterable of multi-index column selections
    into a list of column names directly understandable by PyArrow.

    Parameters
    ----------
    pd_index : `pandas.MultiIndex`
        Pandas multi-index.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Columns to standardize.
    stringify : `bool`, optional
        Should the column names be stringified?

    Returns
    -------
    names : `list` [`str`]
        Stringified representations of the multi-index column names.
    """
    index_level_names = tuple(pd_index.names)

    names: list[str | Sequence[Any]] = []

    if isinstance(columns, list):
        for requested in columns:
            if not isinstance(requested, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(requested)}."
                )
            if stringify:
                names.append(str(requested))
            else:
                names.append(requested)
    else:
        if not isinstance(columns, collections.abc.Mapping):
            raise ValueError(
                "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                f"Instead got a {get_full_type_name(columns)}."
            )
        if not set(index_level_names).issuperset(columns.keys()):
            raise ValueError(
                f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
            )
        factors = [
            ensure_iterable(columns.get(level, pd_index.levels[i]))
            for i, level in enumerate(index_level_names)
        ]
        for requested in itertools.product(*factors):
            for i, value in enumerate(requested):
                if value not in pd_index.levels[i]:
                    raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.")
            if stringify:
                names.append(str(requested))
            else:
                names.append(requested)

    return names


def _apply_astropy_metadata(astropy_table: atable.Table, arrow_schema: pa.Schema) -> None:
    """Apply any astropy metadata from the schema metadata.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to which the metadata are applied.
    arrow_schema : `pyarrow.Schema`
        Arrow schema with metadata.
    """
    from astropy.table import meta

    metadata = arrow_schema.metadata if arrow_schema.metadata is not None else {}

    # Check if we have a special astropy metadata header yaml.
    meta_yaml = metadata.get(b"table_meta_yaml", None)
    if meta_yaml:
        meta_yaml = meta_yaml.decode("UTF8").split("\n")
        meta_hdr = meta.get_header_from_yaml(meta_yaml)

        # Set description, format, unit, meta from the column
        # metadata that was serialized with the table.
        header_cols = {x["name"]: x for x in meta_hdr["datatype"]}
        for col in astropy_table.columns.values():
            for attr in ("description", "format", "unit", "meta"):
                if attr in header_cols[col.name]:
                    setattr(col, attr, header_cols[col.name][attr])

        if "meta" in meta_hdr:
            astropy_table.meta.update(meta_hdr["meta"])
    else:
        # If we don't have astropy header data, we may have arrow field
        # metadata.
        for name in arrow_schema.names:
            field_metadata = arrow_schema.field(name).metadata
            if field_metadata is None:
                continue
            if (
                b"description" in field_metadata
                and (description := field_metadata[b"description"].decode("UTF-8")) != ""
            ):
                astropy_table[name].description = description
            if b"unit" in field_metadata and (unit := field_metadata[b"unit"].decode("UTF-8")) != "":
                astropy_table[name].unit = unit


def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Get the numpy dtype string associated with an arrow column.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column from which to determine the numpy string dtype.
    default_length : `int`, optional
        Default string length used when the length is neither in the
        metadata nor inferable from the column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string.
    """
    # Special-case for string and binary columns
    md_name = f"lsst::arrow::len::{name}"
    strlen = default_length
    metadata = schema.metadata if schema.metadata is not None else {}
    if (encoded := md_name.encode("UTF-8")) in metadata:
        # String/bytes length from header.
        strlen = int(schema.metadata[encoded])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(len(row) for row in numpy_column)

    dtype = f"U{strlen}" if schema.field(name).type == pa.string() else f"|S{strlen}"

    return dtype


def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Append numpy string length keys to arrow metadata.

    All column types are handled, but the metadata is only modified for
    string and byte columns.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype.
    """
    import numpy as np

    if dtype.type is np.str_:
        metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize // 4)
    elif dtype.type is np.bytes_:
        metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize)


def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Append numpy multi-dimensional shapes to arrow metadata.

    All column types are handled, but the metadata is only modified for
    multi-dimensional columns.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype.
    """
    if len(dtype.shape) > 1:
        metadata[f"lsst::arrow::shape::{name}".encode()] = str(dtype.shape)


def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
    """Retrieve the shape from the metadata, if available.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `bytes`]
        Metadata dictionary.
    list_size : `int`
        Size of the list datatype.
    name : `str`
        Column name.

    Returns
    -------
    shape : `tuple` [`int`]
        Shape associated with the column.

    Raises
    ------
    RuntimeError
        Raised if metadata is found but has incorrect format.
    """
    md_name = f"lsst::arrow::shape::{name}"
    if (encoded := md_name.encode("UTF-8")) in metadata:
        groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8"))
        if groups is None:
            raise RuntimeError("Illegal value found in metadata.")
        shape = tuple(int(x) for x in groups[1].split(",") if x != "")
    else:
        shape = (list_size,)

    return shape


def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Convert a pyarrow schema to a numpy dtype.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list : `list` [`tuple`]
        A list with name, type pairs.
    """
    metadata = schema.metadata if schema.metadata is not None else {}

    dtype: list[Any] = []
    for name in schema.names:
        t = schema.field(name).type
        if isinstance(t, pa.FixedSizeListType):
            shape = _multidim_shape_from_metadata(metadata, t.list_size, name)
            dtype.append((name, (t.value_type.to_pandas_dtype(), shape)))
        elif t not in (pa.string(), pa.binary()):
            dtype.append((name, t.to_pandas_dtype()))
        else:
            dtype.append((name, _arrow_string_to_numpy_dtype(schema, name)))

    return dtype


def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Convert a numpy dtype to a list of arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        Converted list of arrow types.
    """
    from math import prod

    import numpy as np

    type_list: list[Any] = []
    if dtype.names is None:
        return type_list

    for name in dtype.names:
        dt = dtype[name]
        arrow_type: Any
        if len(dt.shape) > 0:
            arrow_type = pa.list_(
                pa.from_numpy_dtype(cast(tuple[np.dtype, tuple[int, ...]], dt.subdtype)[0].type),
                prod(dt.shape),
            )
        else:
            arrow_type = pa.from_numpy_dtype(dt.type)
        type_list.append((name, arrow_type))

    return type_list


def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
    """Extract the equivalent table dtype from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    dtype : `numpy.dtype`
        dtype of equivalent table.
    rowcount : `int`
        Number of rows in the table.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    import numpy as np

    dtype_list = []
    rowcount = 0
    for name, col in numpy_dict.items():
        if rowcount == 0:
            rowcount = len(col)
        if len(col) != rowcount:
            raise ValueError(f"Column {name} has a different number of rows.")
        if len(col.shape) == 1:
            dtype_list.append((name, col.dtype))
        else:
            dtype_list.append((name, (col.dtype, col.shape[1:])))
    dtype = np.dtype(dtype_list)

    return (dtype, rowcount)


def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
            or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays.
    """
    import numpy as np

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    for name in dtype.names:
        dt = dtype[name]
        val: Any
        if len(dt.shape) > 0:
            if rowcount > 0:
                val = np.split(np_style_arrays[name].ravel(), rowcount)
            else:
                val = []
        else:
            val = np_style_arrays[name]

        try:
            arrow_arrays.append(pa.array(val, type=schema.field(name).type))
        except pa.ArrowNotImplementedError as err:
            # Check if val is big-endian.
            if (np.little_endian and val.dtype.byteorder == ">") or (
                not np.little_endian and val.dtype.byteorder == "="
            ):
                # We need to convert the array to little-endian.
                val2 = val.byteswap()
                val2.dtype = val2.dtype.newbyteorder("<")
                arrow_arrays.append(pa.array(val2, type=schema.field(name).type))
            else:
                # This failed for some other reason so raise the exception.
                raise err

    return arrow_arrays


def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute an approximate row group size for a given arrow schema.

    Given a schema, this routine computes the number of rows per row group
    so that the uncompressed size approximately matches the target; the
    persisted size on disk will be the same or smaller. The exact size on
    disk depends on the compression settings and ratios; typical binary
    data tables see around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size.
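
    Examples
    --------
    A deterministic sketch (illustrative only): a single ``float64`` column
    is 8 bytes per row, so an 8000-byte target gives 1000 rows.

    >>> import pyarrow as pa
    >>> compute_row_group_size(pa.schema([("x", pa.float64())]), target_size=8_000)
    1000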

    """
    bit_width = 0

    metadata = schema.metadata if schema.metadata is not None else {}

    for name in schema.names:
        t = schema.field(name).type

        if t in (pa.string(), pa.binary()):
            md_name = f"lsst::arrow::len::{name}"

            if (encoded := md_name.encode("UTF-8")) in metadata:
                # String/bytes length from header.
                strlen = int(schema.metadata[encoded])
            else:
                # We don't know the string width, so guess something.
                strlen = 10

            # Assuming UTF-8 encoding, and very few wide characters.
            t_width = 8 * strlen
        elif isinstance(t, pa.FixedSizeListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                t_width = t.list_size * t.value_type.bit_width
        elif t == pa.null():
            t_width = 0
        elif isinstance(t, pa.ListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                # This is a variable length list, just choose
                # something arbitrary.
                t_width = 10 * t.value_type.bit_width
        else:
            t_width = t.bit_width

        bit_width += t_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    if bit_width < 8:
        bit_width = 8

    byte_width = bit_width // 8

    return target_size // byte_width