Coverage for python/lsst/daf/butler/formatters/parquet.py: 12%

443 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-03 09:15 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

from __future__ import annotations

# Public API of this module: the butler formatter plus the standalone
# conversion helpers between arrow / pandas / astropy / numpy tabular forms.
__all__ = (
    "ParquetFormatter",
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    "pandas_to_arrow",
    "pandas_to_astropy",
    "astropy_to_arrow",
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
    "compute_row_group_size",
)

import collections.abc
import itertools
import json
import re
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, cast

import pyarrow as pa
import pyarrow.parquet as pq
from lsst.daf.butler import Formatter
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

# The heavyweight table packages are only imported at module scope for type
# checking; at runtime they are imported lazily inside the functions that
# need them so that importing this module stays cheap.
if TYPE_CHECKING:
    import astropy.table as atable
    import numpy as np
    import pandas as pd

# Target on-disk size of a parquet row group, in bytes; presumably consumed
# by compute_row_group_size (defined later in this file) — TODO confirm.
TARGET_ROW_GROUP_BYTES = 1_000_000_000

61 

62 

class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: Optional[str] = None) -> Any:
        # Docstring inherited from Formatter.read.
        #
        # Reads the parquet file at the formatter's location.  The "columns"
        # and "schema" components return the arrow schema only; "rowcount"
        # returns an int; otherwise a pyarrow Table is returned, optionally
        # restricted by a "columns" read parameter.

        # Read only the schema first; cheap relative to the full table.
        schema = pq.read_schema(self.fileDescriptor.location.path)

        if component in ("columns", "schema"):
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in schema.metadata:
                return int(schema.metadata[b"lsst::arrow::rowcount"])

            # No cached rowcount: read a single column and measure its length.
            temp_table = pq.read_table(
                self.fileDescriptor.location.path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # Pop so that the leftover-parameter check below sees only
            # parameters this formatter does not understand.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in schema.metadata:
                    md = json.loads(schema.metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    # Leading "__" marks internal/index columns; they are not
                    # user-selectable here.
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    # Multi-index columns are stored flattened; convert the
                    # requested tuples/dict to the on-disk string names.
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        metadata = schema.metadata if schema.metadata is not None else {}
        arrow_table = pq.read_table(
            self.fileDescriptor.location.path,
            columns=par_columns,
            use_threads=False,
            # Only honor pandas metadata when the file was written with it.
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Docstring inherited from Formatter.write.
        #
        # Accepts a pyarrow Table, astropy Table, numpy structured array,
        # dict of numpy arrays, or pandas DataFrame; anything else raises
        # ValueError.
        import numpy as np
        from astropy.table import Table as astropyTable

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        # Size row groups from the schema so very wide tables do not produce
        # oversized groups.
        row_group_size = compute_row_group_size(arrow_table.schema)

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)

174 

175 

def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Convert a pyarrow table to a pandas DataFrame.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table to convert. If the table has ``pandas`` metadata
        in the schema it will be used in the construction of the
        ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        Converted pandas dataframe.
    """
    # Single-threaded conversion; nullable integer columns become object
    # columns rather than being promoted to float.
    dataframe = arrow_table.to_pandas(use_threads=False, integer_object_nulls=True)
    return dataframe

192 

193 

def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Convert a pyarrow table to an `astropy.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table to convert. If the table has astropy unit
        metadata in the schema it will be used in the construction
        of the ``astropy.Table``.

    Returns
    -------
    table : `astropy.Table`
        Converted astropy table.
    """
    from astropy.table import Table

    # Build the table from plain numpy columns first, then layer on any
    # serialized astropy metadata (units, descriptions, formats).
    table = Table(arrow_to_numpy_dict(arrow_table))

    md = arrow_table.schema.metadata
    _apply_astropy_metadata(table, md if md is not None else {})

    return table

218 

219 

def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Convert a pyarrow table to a structured numpy array.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Numpy array table with N rows and the same column names
        as the input arrow table.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    # Assemble the structured dtype; multidimensional columns carry their
    # per-row shape as a sub-array dtype.
    fields = []
    for column_name, column in columns.items():
        column_shape = column.shape
        if len(column_shape) <= 1:
            fields.append((column_name, column.dtype))
        else:
            fields.append((column_name, (column.dtype, column_shape[1:])))

    return np.rec.fromarrays(columns.values(), dtype=fields)

248 

249 

def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Convert a pyarrow table to a dict of numpy arrays.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
        Columns containing nulls come back as `numpy.ma.masked_array`.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = schema.metadata if schema.metadata is not None else {}

    numpy_dict = {}

    for name in schema.names:
        t = schema.field(name).type

        if arrow_table[name].null_count == 0:
            # Regular non-masked column
            col = arrow_table[name].to_numpy()
        else:
            # For a masked column, we need to ask arrow to fill the null
            # values with an appropriately typed value before conversion.
            # Then we apply the mask to get a masked array of the correct type.

            if t in (pa.string(), pa.binary()):
                dummy = ""
            else:
                # Zero of the column's equivalent numpy/pandas scalar type.
                dummy = t.to_pandas_dtype()(0)

            col = np.ma.masked_array(
                data=arrow_table[name].fill_null(dummy).to_numpy(),
                mask=arrow_table[name].is_null().to_numpy(),
            )

        if t in (pa.string(), pa.binary()):
            # Convert object strings to fixed-width numpy strings/bytes,
            # using the serialized length metadata when available.
            col = col.astype(_arrow_string_to_numpy_dtype(schema, name, col))
        elif isinstance(t, pa.FixedSizeListType):
            # Fixed-size list columns represent multidimensional data.
            if len(col) > 0:
                col = np.stack(col)
            else:
                # this is an empty column, and needs to be coerced to type.
                col = col.astype(t.value_type.to_pandas_dtype())

            # Restore the original per-row shape recorded in the metadata
            # (falls back to the flat list size otherwise).
            shape = _multidim_shape_from_metadata(metadata, t.list_size, name)
            col = col.reshape((len(arrow_table), *shape))

        numpy_dict[name] = col

    return numpy_dict

306 

307 

def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Convert a dict of numpy arrays to a structured numpy array.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Numpy array table with N rows and columns names from the dict keys.
    """
    # Round-trip through arrow, which validates and normalizes the columns.
    intermediate = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(intermediate)

322 

323 

def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Convert a structured numpy array to a dict of numpy arrays.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    # Round-trip through arrow so the per-column handling matches the other
    # conversion paths in this module.
    intermediate = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(intermediate)

338 

339 

def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Convert a numpy array table to an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
    """
    arrow_types = _numpy_dtype_to_arrow_types(np_array.dtype)

    # Record the row count plus per-column string lengths and
    # multidimensional shapes so the round trip back to numpy is lossless.
    metadata = {b"lsst::arrow::rowcount": str(len(np_array))}
    for field_name in np_array.dtype.names:
        _append_numpy_string_metadata(metadata, field_name, np_array.dtype[field_name])
        _append_numpy_multidim_metadata(metadata, field_name, np_array.dtype[field_name])

    arrow_schema = pa.schema(arrow_types, metadata=metadata)

    columns = _numpy_style_arrays_to_arrow_arrays(
        np_array.dtype,
        len(np_array),
        np_array,
        arrow_schema,
    )

    return pa.Table.from_arrays(columns, schema=arrow_schema)

374 

375 

def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Convert a dict of numpy arrays to an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.

    Raises
    ------
    ValueError if columns in numpy_dict have unequal numbers of rows.
    """
    # Derive a combined structured dtype and the (validated) common length.
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
    arrow_types = _numpy_dtype_to_arrow_types(dtype)

    metadata = {b"lsst::arrow::rowcount": str(rowcount)}

    # Record string lengths and multidimensional shapes per column so the
    # reverse conversion can reconstruct the original arrays.
    if dtype.names is not None:
        for field_name in dtype.names:
            _append_numpy_string_metadata(metadata, field_name, dtype[field_name])
            _append_numpy_multidim_metadata(metadata, field_name, dtype[field_name])

    arrow_schema = pa.schema(arrow_types, metadata=metadata)

    columns = _numpy_style_arrays_to_arrow_arrays(
        dtype,
        rowcount,
        numpy_dict,
        arrow_schema,
    )

    return pa.Table.from_arrays(columns, schema=arrow_schema)

416 

417 

def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Convert an astropy table to an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.  Column units/descriptions and table meta
        are preserved as serialized YAML in the schema metadata.
    """
    from astropy.table import meta

    type_list = _numpy_dtype_to_arrow_types(astropy_table.dtype)

    md = {}
    md[b"lsst::arrow::rowcount"] = str(len(astropy_table))

    # Per-column string length and multidimensional shape metadata; needed
    # to reconstruct fixed-width numpy columns on read.
    for name in astropy_table.dtype.names:
        _append_numpy_string_metadata(md, name, astropy_table.dtype[name])
        _append_numpy_multidim_metadata(md, name, astropy_table.dtype[name])

    # Serialize the astropy-level metadata (units, descriptions, formats,
    # table meta) as YAML so _apply_astropy_metadata can restore it.
    meta_yaml = meta.get_yaml_from_table(astropy_table)
    meta_yaml_str = "\n".join(meta_yaml)
    md[b"table_meta_yaml"] = meta_yaml_str

    schema = pa.schema(type_list, metadata=md)

    arrays = _numpy_style_arrays_to_arrow_arrays(
        astropy_table.dtype,
        len(astropy_table),
        astropy_table,
        schema,
    )

    arrow_table = pa.Table.from_arrays(arrays, schema=schema)

    return arrow_table

458 

459 

def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Convert an astropy table to a dict of numpy arrays.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    # Route through arrow so column handling matches the other converters.
    return arrow_to_numpy_dict(astropy_to_arrow(astropy_table))

474 

475 

def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        Default string length when not in metadata or can be inferred
        from column.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if not name.startswith("__"):
            if arrow_table[name].type == pa.string():
                # Record the longest valid (non-null) string in the column.
                # Using max(..., default=...) also covers empty columns and
                # columns that are entirely null, where the previous
                # max() over an empty generator raised ValueError.
                strlen = max(
                    (len(row.as_py()) for row in arrow_table[name] if row.is_valid),
                    default=default_length,
                )
                md[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table

513 

514 

def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Convert a pandas dataframe to an astropy table, preserving indexes.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe has a multi-index on its columns.
    """
    import pandas as pd
    from astropy.table import Table

    # Multi-index columns have no astropy equivalent; reject them up front.
    if isinstance(dataframe.columns, pd.MultiIndex):
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    return Table.from_pandas(dataframe, index=True)

535 

536 

def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Convert a pandas dataframe to an dict of numpy arrays.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    # Round-trip through arrow, which normalizes pandas object columns.
    intermediate = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(intermediate)

551 

552 

def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Convert a numpy table to an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Input numpy array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.
    """
    from astropy.table import Table

    # copy=False: wrap the existing array rather than duplicating the data.
    astropy_table = Table(data=np_array, copy=False)
    return astropy_table

569 

570 

def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index.
    """
    import pandas as pd

    # Guard against a schema with no metadata at all; the previous
    # `b"pandas" in schema.metadata` raised TypeError when metadata was
    # None.  This matches the guard used elsewhere in this module.
    metadata = schema.metadata if schema.metadata is not None else {}

    if b"pandas" in metadata:
        md = json.loads(metadata[b"pandas"])
        indexes = md["column_indexes"]
        len_indexes = len(indexes)
    else:
        indexes = []
        len_indexes = 0

    if len_indexes <= 1:
        # Flat columns; skip internal "__"-prefixed (index) columns.
        return pd.Index(name for name in schema.names if not name.startswith("__"))
    else:
        # Multi-index columns are stored as stringified tuples; split them
        # back out into per-level tuples.
        raw_columns = _split_multi_index_column_names(len_indexes, schema.names)
        return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])

598 

599 

def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Convert an arrow schema to a list of string column names.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        Converted list of column names.
    """
    # list() replaces the previous identity comprehension.
    return list(schema.names)

614 

615 

class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # A zero-row slice keeps the dtypes and column structure but none
        # of the data.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty row list gives a zero-row table with the full schema.
        # (The previous `[] * len(schema.names)` was always just `[]`.)
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        # The wrapped schema is the empty dataframe itself; the previous
        # `np.dtype` return annotation was incorrect.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)

690 

691 

class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        # A zero-row slice preserves the column structure without any data.
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Convert an arrow schema into a ArrowAstropySchema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        # Build an empty table with the equivalent structured dtype, then
        # layer on any serialized astropy metadata.
        empty = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        astropy_table = Table(data=empty)

        md = schema.metadata
        _apply_astropy_metadata(astropy_table, md if md is not None else {})

        return cls(astropy_table)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        return astropy_to_arrow(self._schema).schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a DataFrameSchema.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        return DataFrameSchema.from_arrow(astropy_to_arrow(self._schema).schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(astropy_to_arrow(self._schema).schema)

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # Matching dtypes implies matching column names as well.
        if self._schema.dtype != other._schema.dtype:
            return False

        # Column attributes (unit/description/format) must also agree.
        return all(
            self._schema[name].unit == other._schema[name].unit
            and self._schema[name].description == other._schema[name].description
            and self._schema[name].format == other._schema[name].format
            for name in self._schema.columns
        )

787 

788 

class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype to convert.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Convert an arrow schema into an `ArrowNumpySchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        return cls(np.dtype(_schema_to_dtype_list(schema)))

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to a `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        # Convert a zero-row array; only the schema is of interest.
        return numpy_to_arrow(np.zeros(0, dtype=self._dtype)).schema

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        return DataFrameSchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        return self._dtype == other._dtype

872 

873 

874def _split_multi_index_column_names(n: int, names: Iterable[str]) -> List[Sequence[str]]: 

875 """Split a string that represents a multi-index column. 

876 

877 PyArrow maps Pandas' multi-index column names (which are tuples in Python) 

878 to flat strings on disk. This routine exists to reconstruct the original 

879 tuple. 

880 

881 Parameters 

882 ---------- 

883 n : `int` 

884 Number of levels in the `pandas.MultiIndex` that is being 

885 reconstructed. 

886 names : `~collections.abc.Iterable` [`str`] 

887 Strings to be split. 

888 

889 Returns 

890 ------- 

891 column_names : `list` [`tuple` [`str`]] 

892 A list of multi-index column name tuples. 

893 """ 

894 column_names: List[Sequence[str]] = [] 

895 

896 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n))) 

897 for name in names: 

898 m = re.search(pattern, name) 

899 if m is not None: 

900 column_names.append(m.groups()) 

901 

902 return column_names 

903 

904 

def _standardize_multi_index_columns(
    pd_index: pd.MultiIndex,
    columns: Any,
    stringify: bool = True,
) -> list[str | Sequence[Any]]:
    """Transform a dictionary/iterable index from a multi-index column list
    into a string directly understandable by PyArrow.

    Parameters
    ----------
    pd_index : `pandas.MultiIndex`
        Pandas multi-index.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Columns to standardize.
    stringify : `bool`, optional
        Should the column names be stringified?

    Returns
    -------
    names : `list` [`str`]
        Stringified representation of a multi-index column name.

    Raises
    ------
    ValueError
        Raised if ``columns`` is neither a list of tuples nor a mapping,
        if a dict key is not an index level name, or if a requested value
        is not present in the corresponding index level.
    """
    index_level_names = tuple(pd_index.names)

    names: list[str | Sequence[Any]] = []

    if isinstance(columns, list):
        # List form: each entry must be a full tuple naming every level.
        for requested in columns:
            if not isinstance(requested, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(requested)}."
                )
            if stringify:
                # str(tuple) matches the flattened on-disk column name.
                names.append(str(requested))
            else:
                names.append(requested)
    else:
        if not isinstance(columns, collections.abc.Mapping):
            raise ValueError(
                "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                f"Instead got a {get_full_type_name(columns)}."
            )
        if not set(index_level_names).issuperset(columns.keys()):
            raise ValueError(
                f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
            )
        # Dict form: take the requested values per level (defaulting to all
        # values of an unspecified level) and expand the cross product.
        factors = [
            ensure_iterable(columns.get(level, pd_index.levels[i]))
            for i, level in enumerate(index_level_names)
        ]
        for requested in itertools.product(*factors):
            for i, value in enumerate(requested):
                if value not in pd_index.levels[i]:
                    raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.")
            if stringify:
                names.append(str(requested))
            else:
                names.append(requested)

    return names

966 

967 

def _apply_astropy_metadata(astropy_table: atable.Table, metadata: dict) -> None:
    """Apply any astropy metadata from the schema metadata.

    The table is modified in place; nothing happens if the metadata does
    not contain a ``table_meta_yaml`` entry (as written by
    ``astropy_to_arrow``).

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to apply metadata.
    metadata : `dict` [`bytes`]
        Metadata dict.
    """
    from astropy.table import meta

    meta_yaml = metadata.get(b"table_meta_yaml", None)
    if meta_yaml:
        # Stored as a single UTF-8 blob; astropy's parser wants lines.
        meta_yaml = meta_yaml.decode("UTF8").split("\n")
        meta_hdr = meta.get_header_from_yaml(meta_yaml)

        # Set description, format, unit, meta from the column
        # metadata that was serialized with the table.
        header_cols = {x["name"]: x for x in meta_hdr["datatype"]}
        for col in astropy_table.columns.values():
            for attr in ("description", "format", "unit", "meta"):
                if attr in header_cols[col.name]:
                    setattr(col, attr, header_cols[col.name][attr])

        # Table-level metadata, if any was serialized.
        if "meta" in meta_hdr:
            astropy_table.meta.update(meta_hdr["meta"])

995 

996 

def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Get the numpy dtype string associated with an arrow column.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column to determine numpy string dtype.
    default_length : `int`, optional
        Default string length when not in metadata or can be inferred
        from column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string.
    """
    # Special-case for string and binary columns
    metadata = schema.metadata if schema.metadata is not None else {}
    encoded = f"lsst::arrow::len::{name}".encode("UTF-8")

    if encoded in metadata:
        # String/bytes length from header.
        strlen = int(metadata[encoded])
    elif numpy_column is not None and len(numpy_column) > 0:
        # No metadata available: measure the longest entry in the column.
        strlen = max(len(row) for row in numpy_column)
    else:
        strlen = default_length

    # Unicode ("U") for arrow strings, raw bytes ("|S") otherwise.
    if schema.field(name).type == pa.string():
        return f"U{strlen}"
    return f"|S{strlen}"

1033 

1034 

1035def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1036 """Append numpy string length keys to arrow metadata. 

1037 

1038 All column types are handled, but the metadata is only modified for 

1039 string and byte columns. 

1040 

1041 Parameters 

1042 ---------- 

1043 metadata : `dict` [`bytes`, `str`] 

1044 Metadata dictionary; modified in place. 

1045 name : `str` 

1046 Column name. 

1047 dtype : `np.dtype` 

1048 Numpy dtype. 

1049 """ 

1050 import numpy as np 

1051 

1052 if dtype.type is np.str_: 

1053 metadata[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(dtype.itemsize // 4) 

1054 elif dtype.type is np.bytes_: 

1055 metadata[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(dtype.itemsize) 

1056 

1057 

1058def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1059 """Append numpy multi-dimensional shapes to arrow metadata. 

1060 

1061 All column types are handled, but the metadata is only modified for 

1062 multi-dimensional columns. 

1063 

1064 Parameters 

1065 ---------- 

1066 metadata : `dict` [`bytes`, `str`] 

1067 Metadata dictionary; modified in place. 

1068 name : `str` 

1069 Column name. 

1070 dtype : `np.dtype` 

1071 Numpy dtype. 

1072 """ 

1073 if len(dtype.shape) > 1: 

1074 metadata[f"lsst::arrow::shape::{name}".encode("UTF-8")] = str(dtype.shape) 

1075 

1076 

1077def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]: 

1078 """Retrieve the shape from the metadata, if available. 

1079 

1080 Parameters 

1081 ---------- 

1082 metadata : `dict` [`bytes`, `bytes`] 

1083 Metadata dictionary. 

1084 list_size : `int` 

1085 Size of the list datatype. 

1086 name : `str` 

1087 Column name. 

1088 

1089 Returns 

1090 ------- 

1091 shape : `tuple` [`int`] 

1092 Shape associated with the column. 

1093 

1094 Raises 

1095 ------ 

1096 RuntimeError 

1097 Raised if metadata is found but has incorrect format. 

1098 """ 

1099 md_name = f"lsst::arrow::shape::{name}" 

1100 if (encoded := md_name.encode("UTF-8")) in metadata: 

1101 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8")) 

1102 if groups is None: 

1103 raise RuntimeError("Illegal value found in metadata.") 

1104 shape = tuple(int(x) for x in groups[1].split(",") if x != "") 

1105 else: 

1106 shape = (list_size,) 

1107 

1108 return shape 

1109 

1110 

def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Convert a pyarrow schema to a numpy dtype.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        A list with name, type pairs.
    """
    metadata = schema.metadata if schema.metadata is not None else {}

    dtype_list: list[Any] = []
    for name in schema.names:
        arrow_type = schema.field(name).type
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists become numpy subarray columns, with the
            # shape recovered from the metadata when present.
            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            dtype_list.append((name, (arrow_type.value_type.to_pandas_dtype(), shape)))
        elif arrow_type in (pa.string(), pa.binary()):
            # String/bytes widths come from the metadata (or a default).
            dtype_list.append((name, _arrow_string_to_numpy_dtype(schema, name)))
        else:
            dtype_list.append((name, arrow_type.to_pandas_dtype()))

    return dtype_list

1138 

1139 

def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Convert a numpy dtype to a list of arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        Converted list of arrow types.
    """
    from math import prod

    import numpy as np

    # Non-structured dtypes have no named fields to convert.
    if dtype.names is None:
        return []

    type_list: list[Any] = []
    for name in dtype.names:
        field_dtype = dtype[name]
        arrow_type: Any
        if field_dtype.shape:
            # Multi-dimensional columns map to fixed-size arrow lists,
            # flattened to the total number of elements per row.
            base_dtype = cast(tuple[np.dtype, tuple[int, ...]], field_dtype.subdtype)[0]
            arrow_type = pa.list_(pa.from_numpy_dtype(base_dtype.type), prod(field_dtype.shape))
        else:
            arrow_type = pa.from_numpy_dtype(field_dtype.type)
        type_list.append((name, arrow_type))

    return type_list

1174 

1175 

1176def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]: 

1177 """Extract equivalent table dtype from dict of numpy arrays. 

1178 

1179 Parameters 

1180 ---------- 

1181 numpy_dict : `dict` [`str`, `numpy.ndarray`] 

1182 Dict with keys as the column names, values as the arrays. 

1183 

1184 Returns 

1185 ------- 

1186 dtype : `numpy.dtype` 

1187 dtype of equivalent table. 

1188 rowcount : `int` 

1189 Number of rows in the table. 

1190 

1191 Raises 

1192 ------ 

1193 ValueError if columns in numpy_dict have unequal numbers of rows. 

1194 """ 

1195 import numpy as np 

1196 

1197 dtype_list = [] 

1198 rowcount = 0 

1199 for name, col in numpy_dict.items(): 

1200 if rowcount == 0: 

1201 rowcount = len(col) 

1202 if len(col) != rowcount: 

1203 raise ValueError(f"Column {name} has a different number of rows.") 

1204 if len(col.shape) == 1: 

1205 dtype_list.append((name, col.dtype)) 

1206 else: 

1207 dtype_list.append((name, (col.dtype, col.shape[1:]))) 

1208 dtype = np.dtype(dtype_list) 

1209 

1210 return (dtype, rowcount) 

1211 

1212 

def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
        or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays.
    """
    import numpy as np

    # Non-structured dtypes have no named columns to convert.
    if dtype.names is None:
        return []

    arrow_arrays: list[pa.Array] = []
    for name in dtype.names:
        field_dtype = dtype[name]
        values: Any
        if field_dtype.shape:
            # Multi-dimensional column: flatten into one sub-list per row.
            values = np.split(np_style_arrays[name].ravel(), rowcount) if rowcount > 0 else []
        else:
            values = np_style_arrays[name]

        field_type = schema.field(name).type
        try:
            arrow_arrays.append(pa.array(values, type=field_type))
        except pa.ArrowNotImplementedError:
            # pyarrow cannot ingest big-endian data; detect that case.
            byteorder = values.dtype.byteorder
            is_big_endian = (byteorder == ">") if np.little_endian else (byteorder == "=")
            if not is_big_endian:
                # Conversion failed for some other reason; propagate it.
                raise
            # Byteswap to little-endian and retry the conversion.
            swapped = values.byteswap()
            swapped.dtype = swapped.dtype.newbyteorder("<")
            arrow_arrays.append(pa.array(swapped, type=field_type))

    return arrow_arrays

1271 

1272 

def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row
    group that targets the persisted size on disk (or smaller). The exact
    size on disk depends on the compression settings and ratios; typical
    binary data tables will have around 15-20% compression with the
    pyarrow default ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size.
    """
    metadata = schema.metadata if schema.metadata is not None else {}

    def _field_bit_width(name: str) -> int:
        """Estimate the per-row bit width of a single column."""
        field_type = schema.field(name).type

        if field_type in (pa.string(), pa.binary()):
            length_key = f"lsst::arrow::len::{name}".encode("UTF-8")
            # Use the string/bytes length from the header if recorded;
            # otherwise we don't know the width, so guess something.
            strlen = int(metadata[length_key]) if length_key in metadata else 10
            # Assuming UTF-8 encoding, and very few wide characters.
            return 8 * strlen
        if isinstance(field_type, pa.FixedSizeListType):
            if field_type.value_type == pa.null():
                return 0
            return field_type.list_size * field_type.value_type.bit_width
        if field_type == pa.null():
            return 0
        if isinstance(field_type, pa.ListType):
            if field_type.value_type == pa.null():
                return 0
            # Variable-length list; just choose something arbitrary.
            return 10 * field_type.value_type.bit_width
        return field_type.bit_width

    bit_width = sum(_field_bit_width(name) for name in schema.names)

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    bit_width = max(bit_width, 8)

    return target_size // (bit_width // 8)