Coverage for python/lsst/daf/butler/formatters/parquet.py: 13%

469 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-15 02:03 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Names exported by a star-import of this module.
__all__ = (
    "ParquetFormatter",
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    "pandas_to_arrow",
    "pandas_to_astropy",
    "astropy_to_arrow",
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
    "compute_row_group_size",
)

48 

49import collections.abc 

50import itertools 

51import json 

52import re 

53from collections.abc import Iterable, Sequence 

54from typing import TYPE_CHECKING, Any, cast 

55 

56import pyarrow as pa 

57import pyarrow.parquet as pq 

58from lsst.daf.butler import Formatter 

59from lsst.utils.introspection import get_full_type_name 

60from lsst.utils.iteration import ensure_iterable 

61 

62if TYPE_CHECKING: 

63 import astropy.table as atable 

64 import numpy as np 

65 import pandas as pd 

66 

# NOTE(review): presumably the desired size, in bytes, of one Parquet row
# group, consumed by ``compute_row_group_size`` (defined outside this
# view) -- confirm against that function.
TARGET_ROW_GROUP_BYTES = 1_000_000_000

68 

69 

class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)

        # ``schema.metadata`` is `None` rather than an empty mapping when the
        # file carries no key/value metadata. Normalize it once so that every
        # membership test below is safe.
        metadata = schema.metadata if schema.metadata is not None else {}

        schema_names = ["ArrowSchema", "DataFrameSchema", "ArrowAstropySchema", "ArrowNumpySchema"]

        if component in ("columns", "schema") or self.fileDescriptor.readStorageClass.name in schema_names:
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            # Reading a single column is enough to count the rows.
            temp_table = pq.read_table(
                path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # "columns" is the only supported parameter; it is removed from
            # the mapping so that anything left over can be reported below.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    # Names starting with "__" are not selectable columns.
                    file_columns = {name for name in schema.names if not name.startswith("__")}

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        """Write a table-like dataset to a Parquet file.

        Parameters
        ----------
        inMemoryDataset : `typing.Any`
            Dataset to write. Accepted types are `pyarrow.Table`,
            `astropy.table.Table`, `numpy.ndarray`, `dict` of numpy arrays,
            `pyarrow.Schema` (only the file metadata is written), and
            `pandas.DataFrame`.

        Raises
        ------
        ValueError
            Raised if the dataset is a `dict` that cannot be converted as a
            dict of numpy arrays, or if its type is not one of the supported
            types.
        """
        import numpy as np
        from astropy.table import Table as astropyTable

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif isinstance(inMemoryDataset, pa.Schema):
            # A bare schema has no rows; write the metadata only and stop.
            pq.write_metadata(inMemoryDataset, location.path)
            return
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)

186 

187 

def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Build a pandas DataFrame from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert. Any ``pandas`` metadata stored in its
        schema is used when constructing the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    # The conversion is single-threaded; integer columns containing nulls
    # are returned as object columns.
    dataframe = arrow_table.to_pandas(
        use_threads=False,
        integer_object_nulls=True,
    )
    return dataframe

204 

205 

def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Build an `astropy.Table` from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert. Astropy metadata found in the schema
        (for example units) is applied to the resulting table.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    # Go through the dict-of-numpy-arrays representation first, then
    # decorate the resulting columns with the schema metadata.
    columns = arrow_to_numpy_dict(arrow_table)
    result = Table(columns)
    _apply_astropy_metadata(result, arrow_table.schema)
    return result

228 

229 

def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Build a structured numpy array from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names match the column names
        of the input table.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    def _field(name: str, values: np.ndarray) -> tuple:
        # One-dimensional columns map to a plain field; multi-dimensional
        # columns carry their per-row shape as a sub-array dtype.
        shape = values.shape
        if len(shape) <= 1:
            return (name, values.dtype)
        return (name, (values.dtype, shape[1:]))

    dtype = [_field(name, values) for name, values in columns.items()]

    return np.rec.fromarrays(columns.values(), dtype=dtype)

258 

259 

def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Arrow table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's values. Columns that
        contain nulls are returned as masked arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = schema.metadata if schema.metadata is not None else {}
    string_like = (pa.string(), pa.binary())

    numpy_dict = {}

    for name in schema.names:
        arrow_type = schema.field(name).type
        arrow_column = arrow_table[name]

        if arrow_column.null_count == 0:
            # Plain column without any nulls.
            values = arrow_column.to_numpy()
        else:
            # A column with nulls becomes a masked array: arrow first fills
            # the nulls with a placeholder of a suitable type, and the null
            # positions are then used as the mask.
            placeholder: Any
            if arrow_type in (pa.float64(), pa.float32(), pa.float16()):
                placeholder = np.nan
            elif arrow_type in (pa.int64(), pa.int32(), pa.int16(), pa.int8()):
                placeholder = -1
            elif arrow_type in (pa.bool_(),):
                placeholder = True
            elif arrow_type in string_like:
                placeholder = ""
            else:
                # Fallback, used for unsigned integers in particular.
                placeholder = 0

            values = np.ma.masked_array(
                data=arrow_column.fill_null(placeholder).to_numpy(),
                mask=arrow_column.is_null().to_numpy(),
                fill_value=placeholder,
            )

        if arrow_type in string_like:
            values = values.astype(_arrow_string_to_numpy_dtype(schema, name, values))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(values) > 0:
                values = np.stack(values)
            else:
                # An empty column must be coerced to the element type.
                values = values.astype(arrow_type.value_type.to_pandas_dtype())

            # Restore the per-row shape from the metadata.
            row_shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            values = values.reshape((len(arrow_table), *row_shape))

        numpy_dict[name] = values

    return numpy_dict

325 

326 

def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Build a structured numpy array from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's values.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names are the dict keys.
    """
    # The conversion is a round trip through an arrow table.
    as_arrow = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(as_arrow)

341 

342 

def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Numpy array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's values.
    """
    # The conversion is a round trip through an arrow table.
    as_arrow = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(as_arrow)

357 

358 

def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Build an arrow table from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Numpy array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. Its schema metadata records the row
        count under the ``lsst::arrow::rowcount`` key.
    """
    dtype = np_array.dtype
    n_rows = len(np_array)

    type_list = _numpy_dtype_to_arrow_types(dtype)

    schema_metadata = {b"lsst::arrow::rowcount": str(n_rows)}
    for field_name in dtype.names:
        field_dtype = dtype[field_name]
        _append_numpy_string_metadata(schema_metadata, field_name, field_dtype)
        _append_numpy_multidim_metadata(schema_metadata, field_name, field_dtype)

    schema = pa.schema(type_list, metadata=schema_metadata)

    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, n_rows, np_array, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

393 

394 

def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's values.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. Its schema metadata records the row
        count under the ``lsst::arrow::rowcount`` key.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    dtype, n_rows = _numpy_dict_to_dtype(numpy_dict)
    type_list = _numpy_dtype_to_arrow_types(dtype)

    schema_metadata = {b"lsst::arrow::rowcount": str(n_rows)}

    # ``dtype.names`` is `None` for a dtype without fields.
    for field_name in dtype.names if dtype.names is not None else ():
        field_dtype = dtype[field_name]
        _append_numpy_string_metadata(schema_metadata, field_name, field_dtype)
        _append_numpy_multidim_metadata(schema_metadata, field_name, field_dtype)

    schema = pa.schema(type_list, metadata=schema_metadata)

    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, n_rows, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

435 

436 

def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Astropy table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. The schema metadata holds the row
        count and the astropy table header serialized as YAML; column
        descriptions and units are stored as field metadata.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    n_rows = len(astropy_table)

    type_list = _numpy_dtype_to_arrow_types(table_dtype)

    schema_metadata = {b"lsst::arrow::rowcount": str(n_rows)}
    for column_name in table_dtype.names:
        column_dtype = table_dtype[column_name]
        _append_numpy_string_metadata(schema_metadata, column_name, column_dtype)
        _append_numpy_multidim_metadata(schema_metadata, column_name, column_dtype)

    # Store the complete astropy header as a single YAML string.
    yaml_lines = meta.get_yaml_from_table(astropy_table)
    schema_metadata[b"table_meta_yaml"] = "\n".join(yaml_lines)

    # Turn the (name, type) pairs into fields that carry per-column
    # metadata.
    fields = []
    for column_name, arrow_type in type_list:
        per_field = {}
        description = astropy_table[column_name].description
        if description:
            per_field["description"] = description
        unit = astropy_table[column_name].unit
        if unit:
            per_field["unit"] = str(unit)
        fields.append(pa.field(column_name, arrow_type, metadata=per_field))

    schema = pa.schema(fields, metadata=schema_metadata)

    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, astropy_table, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

493 

494 

def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Astropy table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's values.
    """
    # The conversion is a round trip through an arrow table.
    as_arrow = astropy_to_arrow(astropy_table)
    return arrow_to_numpy_dict(as_arrow)

509 

510 

def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column when it cannot be
        inferred from the column, that is when the column is empty or
        contains only nulls.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table. Its schema metadata additionally records
        the row count (``lsst::arrow::rowcount``) and, for every string
        column, the maximum string length (``lsst::arrow::len::<name>``).
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if not name.startswith("__") and arrow_table[name].type == pa.string():
            # ``default`` covers both an empty column and a column whose
            # values are all null; without it ``max`` raises ValueError on
            # the empty sequence left after filtering out the nulls.
            strlen = max(
                (len(row.as_py()) for row in arrow_table[name] if row.is_valid),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table

547 

548 

def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Build an astropy table from a pandas dataframe, keeping the index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe columns are a `pandas.MultiIndex`.
    """
    import pandas as pd
    from astropy.table import Table

    has_multi_index_columns = isinstance(dataframe.columns, pd.MultiIndex)
    if has_multi_index_columns:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table

569 

570 

def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pandas dataframe.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column's values.
    """
    # The conversion is a round trip through an arrow table.
    as_arrow = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(as_arrow)

585 

586 

def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Build an astropy table from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Numpy array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table, constructed with ``copy=False``.
    """
    from astropy.table import Table

    astropy_table = Table(data=np_array, copy=False)
    return astropy_table

603 

604 

def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only when
        the ``pandas`` schema metadata declares more than one column index
        level; otherwise a flat `pandas.Index` of the column names that do
        not start with ``__`` is returned.
    """
    import pandas as pd

    # ``schema.metadata`` is `None` when the schema carries no metadata;
    # treat that the same as "no pandas metadata".
    metadata = schema.metadata if schema.metadata is not None else {}

    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]
    else:
        indexes = []

    if len(indexes) <= 1:
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    # Multi-index column names are stored as flat strings; recover the
    # tuples before building the index.
    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])

632 

633 

def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Return the column names of an arrow schema as a list of strings.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        New list holding the schema's column names, in schema order.
    """
    column_list = [*schema.names]
    return column_list

648 

649 

class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row selection of the dataframe, which
    keeps the column labels and dtypes but none of the data.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean mask selects no rows.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # A table with no rows is enough to carry the schema to pandas.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        """The zero-row dataframe holding the schema
        (`pandas.DataFrame`).
        """
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)

724 

725 

class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    The schema is stored as a zero-row slice of the table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Build an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        # An empty structured array carries the column names and dtypes.
        empty = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        table = Table(data=empty)
        _apply_astropy_metadata(table, schema)

        return cls(table)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = astropy_to_arrow(self._schema)
        return arrow_table.schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return DataFrameSchema.from_arrow(arrow_schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return ArrowNumpySchema.from_arrow(arrow_schema)

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        mine, theirs = self._schema, other._schema

        # Equal dtypes imply that both tables have the same column names.
        if mine.dtype != theirs.dtype:
            return False

        # The per-column attributes must agree as well.
        for name in mine.columns:
            for attr in ("unit", "description", "format"):
                if not getattr(mine[name], attr) == getattr(theirs[name], attr):
                    return False

        return True

819 

820 

class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype to convert.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Build an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        dtype_list = _schema_to_dtype_list(schema)
        return cls(np.dtype(dtype_list))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np

        # The schema of an empty array of this dtype is used as the
        # intermediate arrow schema.
        empty = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to a `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        # Two schemas are equal exactly when their dtypes are equal.
        return True if self._dtype == other._dtype else False

904 

905 

def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]:
    """Recover multi-index column tuples from their flat string form.

    PyArrow stores the column names of a pandas multi-index (tuples in
    Python) as flat strings on disk. This routine rebuilds the original
    tuples from those strings.

    Parameters
    ----------
    n : `int`
        Number of levels in the `pandas.MultiIndex` that is being
        reconstructed.
    names : `~collections.abc.Iterable` [`str`]
        Strings to be split.

    Returns
    -------
    column_names : `list` [`tuple` [`str`]]
        One tuple of level values per input string that looks like a
        stringified ``n``-tuple; strings that do not match are skipped.
    """
    # One quoted capture group per level, separated by ", " and wrapped in
    # literal parentheses.
    level_groups = ", ".join(["'(.*)'"] * n)
    matcher = re.compile(r"\({}\)".format(level_groups))

    matches = (matcher.search(name) for name in names)
    return [found.groups() for found in matches if found is not None]

935 

936 

def _standardize_multi_index_columns(
    pd_index: pd.MultiIndex,
    columns: Any,
    stringify: bool = True,
) -> list[str | Sequence[Any]]:
    """Normalize a multi-index column selection into a list of column
    names, optionally in the stringified form understood by PyArrow.

    Parameters
    ----------
    pd_index : `pandas.MultiIndex`
        Pandas multi-index.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Columns to standardize. A list must contain complete column
        tuples. A mapping goes from index level name to the value(s)
        wanted for that level; levels that are not mentioned select all
        of their values.
    stringify : `bool`, optional
        Should the column names be stringified?

    Returns
    -------
    names : `list` [`str`]
        Selected column names, as strings if ``stringify`` is `True` and
        as tuples otherwise.

    Raises
    ------
    ValueError
        Raised if ``columns`` is neither a list of tuples nor a mapping,
        if the mapping uses keys that are not index level names, or if a
        requested value does not exist in its index level.
    """
    level_names = tuple(pd_index.names)

    def _render(column: Sequence[Any]) -> str | Sequence[Any]:
        # Apply the representation requested by the caller.
        return str(column) if stringify else column

    if isinstance(columns, list):
        selected: list[str | Sequence[Any]] = []
        for requested in columns:
            if not isinstance(requested, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(requested)}."
                )
            selected.append(_render(requested))
        return selected

    if not isinstance(columns, collections.abc.Mapping):
        raise ValueError(
            "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
            f"Instead got a {get_full_type_name(columns)}."
        )
    if not set(level_names).issuperset(columns.keys()):
        raise ValueError(
            f"Cannot use dict with keys {set(columns.keys())} to select columns from {level_names}."
        )

    # For every level use the requested values, or all values of the level
    # when none were requested, then expand to every combination.
    per_level = [
        ensure_iterable(columns.get(level, pd_index.levels[position]))
        for position, level in enumerate(level_names)
    ]
    selected = []
    for requested in itertools.product(*per_level):
        for position, value in enumerate(requested):
            if value not in pd_index.levels[position]:
                raise ValueError(f"Unrecognized value {value!r} for index {level_names[position]!r}.")
        selected.append(_render(requested))

    return selected

998 

999 

def _apply_astropy_metadata(astropy_table: atable.Table, arrow_schema: pa.Schema) -> None:
    """Copy astropy metadata stored in an arrow schema onto a table.

    The table is modified in place. If the schema metadata contains a
    serialized astropy header (``table_meta_yaml``), the column attributes
    and table ``meta`` are taken from it. Otherwise the ``description``
    and ``unit`` entries of the arrow field metadata are used.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to apply metadata.
    arrow_schema : `pyarrow.Schema`
        Arrow schema with metadata.
    """
    from astropy.table import meta

    schema_metadata = arrow_schema.metadata if arrow_schema.metadata is not None else {}

    serialized_header = schema_metadata.get(b"table_meta_yaml", None)
    if not serialized_header:
        # No astropy header; fall back to the arrow field metadata.
        for name in arrow_schema.names:
            field_metadata = arrow_schema.field(name).metadata
            if field_metadata is None:
                continue
            if b"description" in field_metadata:
                description = field_metadata[b"description"].decode("UTF-8")
                if description != "":
                    astropy_table[name].description = description
            if b"unit" in field_metadata:
                unit = field_metadata[b"unit"].decode("UTF-8")
                if unit != "":
                    astropy_table[name].unit = unit
        return

    header_lines = serialized_header.decode("UTF8").split("\n")
    header = meta.get_header_from_yaml(header_lines)

    # Restore description, format, unit and meta for every column from the
    # header that was serialized with the table.
    columns_by_name = {entry["name"]: entry for entry in header["datatype"]}
    for col in astropy_table.columns.values():
        column_header = columns_by_name[col.name]
        for attr in ("description", "format", "unit", "meta"):
            if attr in column_header:
                setattr(col, attr, column_header[attr])

    if "meta" in header:
        astropy_table.meta.update(header["meta"])

1044 

1045 

def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Get the numpy dtype string associated with an arrow column.

    The string length is taken, in order of preference, from the
    ``lsst::arrow::len::<name>`` schema metadata key, from the longest
    non-null entry of ``numpy_column``, or from ``default_length``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column to determine numpy string dtype.
    default_length : `int`, optional
        Default string length when it is not in the metadata and cannot be
        inferred from the column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string: ``U<len>`` when the arrow field type is
        ``pa.string()``, ``|S<len>`` otherwise.
    """
    # Special-case for string and binary columns
    md_name = f"lsst::arrow::len::{name}"
    strlen = default_length
    metadata = schema.metadata if schema.metadata is not None else {}
    if (encoded := md_name.encode("UTF-8")) in metadata:
        # String/bytes length from header.
        strlen = int(metadata[encoded])
    elif numpy_column is not None and len(numpy_column) > 0:
        # Rows that are None (presumably nulls in an object array) have no
        # length; skip them rather than failing with a TypeError.
        lengths = [len(row) for row in numpy_column if row is not None]
        if lengths:
            strlen = max(lengths)

    dtype = f"U{strlen}" if schema.field(name).type == pa.string() else f"|S{strlen}"

    return dtype

1081 

1082 

1083def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1084 """Append numpy string length keys to arrow metadata. 

1085 

1086 All column types are handled, but the metadata is only modified for 

1087 string and byte columns. 

1088 

1089 Parameters 

1090 ---------- 

1091 metadata : `dict` [`bytes`, `str`] 

1092 Metadata dictionary; modified in place. 

1093 name : `str` 

1094 Column name. 

1095 dtype : `np.dtype` 

1096 Numpy dtype. 

1097 """ 

1098 import numpy as np 

1099 

1100 if dtype.type is np.str_: 

1101 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize // 4) 

1102 elif dtype.type is np.bytes_: 

1103 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize) 

1104 

1105 

1106def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1107 """Append numpy multi-dimensional shapes to arrow metadata. 

1108 

1109 All column types are handled, but the metadata is only modified for 

1110 multi-dimensional columns. 

1111 

1112 Parameters 

1113 ---------- 

1114 metadata : `dict` [`bytes`, `str`] 

1115 Metadata dictionary; modified in place. 

1116 name : `str` 

1117 Column name. 

1118 dtype : `np.dtype` 

1119 Numpy dtype. 

1120 """ 

1121 if len(dtype.shape) > 1: 

1122 metadata[f"lsst::arrow::shape::{name}".encode()] = str(dtype.shape) 

1123 

1124 

1125def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]: 

1126 """Retrieve the shape from the metadata, if available. 

1127 

1128 Parameters 

1129 ---------- 

1130 metadata : `dict` [`bytes`, `bytes`] 

1131 Metadata dictionary. 

1132 list_size : `int` 

1133 Size of the list datatype. 

1134 name : `str` 

1135 Column name. 

1136 

1137 Returns 

1138 ------- 

1139 shape : `tuple` [`int`] 

1140 Shape associated with the column. 

1141 

1142 Raises 

1143 ------ 

1144 RuntimeError 

1145 Raised if metadata is found but has incorrect format. 

1146 """ 

1147 md_name = f"lsst::arrow::shape::{name}" 

1148 if (encoded := md_name.encode("UTF-8")) in metadata: 

1149 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8")) 

1150 if groups is None: 

1151 raise RuntimeError("Illegal value found in metadata.") 

1152 shape = tuple(int(x) for x in groups[1].split(",") if x != "") 

1153 else: 

1154 shape = (list_size,) 

1155 

1156 return shape 

1157 

1158 

def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Build a numpy dtype description from a pyarrow schema.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        A list with name, type pairs, one per schema field.
    """
    schema_metadata = schema.metadata
    if schema_metadata is None:
        schema_metadata = {}

    dtype_list: list[Any] = []
    for name in schema.names:
        arrow_type = schema.field(name).type
        column_type: Any
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists become sub-array dtypes of (base type, shape).
            shape = _multidim_shape_from_metadata(schema_metadata, arrow_type.list_size, name)
            column_type = (arrow_type.value_type.to_pandas_dtype(), shape)
        elif arrow_type in (pa.string(), pa.binary()):
            # String and binary columns need an explicit length.
            column_type = _arrow_string_to_numpy_dtype(schema, name)
        else:
            column_type = arrow_type.to_pandas_dtype()
        dtype_list.append((name, column_type))

    return dtype_list

1186 

1187 

def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Translate the fields of a numpy dtype into arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        List of ``(name, arrow type)`` pairs, one per field of ``dtype``;
        empty if ``dtype`` has no named fields.
    """
    from math import prod

    import numpy as np

    field_names = dtype.names
    if field_names is None:
        return []

    arrow_fields: list[Any] = []
    for field_name in field_names:
        field_dtype = dtype[field_name]
        if len(field_dtype.shape) == 0:
            arrow_fields.append((field_name, pa.from_numpy_dtype(field_dtype.type)))
            continue

        # Sub-array columns are stored as flat fixed-size lists whose length
        # is the product of the sub-array shape.
        base_dtype = cast(tuple[np.dtype, tuple[int, ...]], field_dtype.subdtype)[0]
        flat_length = prod(field_dtype.shape)
        arrow_fields.append((field_name, pa.list_(pa.from_numpy_dtype(base_dtype.type), flat_length)))

    return arrow_fields

1222 

1223 

1224def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]: 

1225 """Extract equivalent table dtype from dict of numpy arrays. 

1226 

1227 Parameters 

1228 ---------- 

1229 numpy_dict : `dict` [`str`, `numpy.ndarray`] 

1230 Dict with keys as the column names, values as the arrays. 

1231 

1232 Returns 

1233 ------- 

1234 dtype : `numpy.dtype` 

1235 dtype of equivalent table. 

1236 rowcount : `int` 

1237 Number of rows in the table. 

1238 

1239 Raises 

1240 ------ 

1241 ValueError if columns in numpy_dict have unequal numbers of rows. 

1242 """ 

1243 import numpy as np 

1244 

1245 dtype_list = [] 

1246 rowcount = 0 

1247 for name, col in numpy_dict.items(): 

1248 if rowcount == 0: 

1249 rowcount = len(col) 

1250 if len(col) != rowcount: 

1251 raise ValueError(f"Column {name} has a different number of rows.") 

1252 if len(col.shape) == 1: 

1253 dtype_list.append((name, col.dtype)) 

1254 else: 

1255 dtype_list.append((name, (col.dtype, col.shape[1:]))) 

1256 dtype = np.dtype(dtype_list) 

1257 

1258 return (dtype, rowcount) 

1259 

1260 

def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Columns whose dtype has a sub-array shape are flattened and split into
    one chunk per row before conversion. If arrow reports the conversion
    as not implemented and the column is big-endian, the column is
    byte-swapped to little-endian and the conversion is retried.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
        or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays, in the order of ``dtype.names``;
        empty if ``dtype`` has no named fields.

    Raises
    ------
    pyarrow.ArrowNotImplementedError
        Raised if the conversion is not supported by arrow for a reason
        other than the column being big-endian.
    """
    import numpy as np

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    for name in dtype.names:
        dt = dtype[name]
        is_multidim = len(dt.shape) > 0
        arrow_type = schema.field(name).type

        val: Any
        if not is_multidim:
            val = np_style_arrays[name]
        elif rowcount > 0:
            val = np.split(np_style_arrays[name].ravel(), rowcount)
        else:
            val = []

        try:
            arrow_arrays.append(pa.array(val, type=arrow_type))
        except pa.ArrowNotImplementedError:
            if is_multidim and rowcount <= 0:
                # There is no data to convert, so endianness is not the
                # cause; let the original error propagate.
                raise

            # Inspect the source column rather than ``val``: for
            # multi-dimensional columns ``val`` is a list of chunks and has
            # no ``dtype`` attribute.
            column = np_style_arrays[name]
            byteorder = column.dtype.byteorder
            is_big_endian = (np.little_endian and byteorder == ">") or (
                not np.little_endian and byteorder == "="
            )
            if not is_big_endian:
                # This failed for some other reason so raise the exception.
                raise

            # We need to convert the array to little-endian.
            swapped = column.byteswap()
            swapped.dtype = swapped.dtype.newbyteorder("<")
            val2 = np.split(swapped.ravel(), rowcount) if is_multidim else swapped
            arrow_arrays.append(pa.array(val2, type=arrow_type))

    return arrow_arrays

1319 

1320 

def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Estimate the number of rows per row group for an arrow schema.

    The per-row width is estimated from the schema and divided into the
    target size, so the result aims at the persisted size on disk (or
    smaller). The exact size on disk depends on the compression settings
    and ratios; typical binary data tables will have around 15-20%
    compression with the pyarrow default ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size.
    """
    schema_metadata = schema.metadata
    if schema_metadata is None:
        schema_metadata = {}

    total_bits = 0
    for name in schema.names:
        arrow_type = schema.field(name).type

        if arrow_type in (pa.string(), pa.binary()):
            length_key = f"lsst::arrow::len::{name}".encode("UTF-8")
            if length_key in schema_metadata:
                # String/bytes length from header.
                strlen = int(schema_metadata[length_key])
            else:
                # We don't know the string width, so guess something.
                strlen = 10
            # Assuming UTF-8 encoding, and very few wide characters.
            column_bits = 8 * strlen
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if arrow_type.value_type == pa.null():
                column_bits = 0
            else:
                column_bits = arrow_type.list_size * arrow_type.value_type.bit_width
        elif arrow_type == pa.null():
            column_bits = 0
        elif isinstance(arrow_type, pa.ListType):
            if arrow_type.value_type == pa.null():
                column_bits = 0
            else:
                # Variable-length list: assume an arbitrary 10 elements.
                column_bits = 10 * arrow_type.value_type.bit_width
        else:
            column_bits = arrow_type.bit_width

        total_bits += column_bits

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    total_bits = max(total_bits, 8)

    return target_size // (total_bits // 8)