Coverage for python/lsst/daf/butler/formatters/parquet.py: 13%

411 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-04-07 00:58 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public API of this module: the formatter, the table conversion helpers
# and the schema wrapper classes.
__all__ = (
    # Formatter.
    "ParquetFormatter",
    # Conversions starting from arrow.
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    # Conversions starting from pandas.
    "pandas_to_arrow",
    "pandas_to_astropy",
    # Conversions starting from astropy.
    "astropy_to_arrow",
    # Conversions starting from numpy.
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    # Schema helpers and wrappers.
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
)

41 

42import collections.abc 

43import itertools 

44import json 

45import re 

46from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast 

47 

48import pyarrow as pa 

49import pyarrow.parquet as pq 

50from lsst.daf.butler import Formatter 

51from lsst.utils.introspection import get_full_type_name 

52from lsst.utils.iteration import ensure_iterable 

53 

54if TYPE_CHECKING: 

55 import astropy.table as atable 

56 import numpy as np 

57 import pandas as pd 

58 

59 

class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: Optional[str] = None) -> Any:
        # Docstring inherited from Formatter.read.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)
        # Parquet files may carry no schema-level metadata, in which case
        # ``schema.metadata`` is None; normalize so membership tests work.
        metadata = schema.metadata if schema.metadata is not None else {}

        if component in ("columns", "schema"):
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])
            # Otherwise use the row count stored in the parquet file footer,
            # which does not require reading any column data (and also
            # works for files with no columns).
            return pq.read_metadata(path).num_rows

        par_columns = None
        if self.fileDescriptor.parameters:
            # The "columns" parameter is consumed here; any parameters that
            # remain afterwards are unsupported by this formatter.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    # Names starting with "__" are internal columns (e.g. a
                    # serialized pandas index) and cannot be requested.
                    file_columns = {name for name in schema.names if not name.startswith("__")}

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(schema, par_columns)

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Convert any supported in-memory table type to an arrow table and
        # write it as parquet.
        import numpy as np
        from astropy.table import Table as astropyTable

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame; pandas is optional, so only
                # import it when the object looks like a dataframe.
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        pq.write_table(arrow_table, location.path)

166 

167 

def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Build a pandas DataFrame from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Any ``pandas`` metadata stored in its schema is
        used when constructing the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The resulting dataframe.
    """
    # Convert single-threaded; ``integer_object_nulls`` asks pyarrow to turn
    # integer columns that contain nulls into object-dtype columns.
    conversion_options = dict(use_threads=False, integer_object_nulls=True)
    dataframe = arrow_table.to_pandas(**conversion_options)
    return dataframe

184 

185 

def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Build an `astropy.Table` from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Astropy column/table metadata serialized in the
        schema (under ``table_meta_yaml``) is restored on the result.

    Returns
    -------
    table : `astropy.Table`
        The resulting astropy table.
    """
    from astropy.table import Table

    result = Table(arrow_to_numpy_dict(arrow_table))

    schema_metadata = arrow_table.schema.metadata
    _apply_astropy_metadata(result, {} if schema_metadata is None else schema_metadata)

    return result

210 

211 

def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Build a structured numpy array from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured (record) array with one row per table row and one field
        per table column, using the same column names.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    def _field_spec(field_name, values):
        # Multi-dimensional columns become sub-array fields whose shape is
        # everything after the leading row axis.
        if values.ndim <= 1:
            return (field_name, values.dtype)
        return (field_name, (values.dtype, values.shape[1:]))

    dtype = [_field_spec(field_name, values) for field_name, values in columns.items()]

    return np.rec.fromarrays(columns.values(), dtype=dtype)

240 

241 

def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Split a pyarrow table into a dict of per-column numpy arrays.

    Columns containing nulls come back as `numpy.ma.MaskedArray` with the
    null positions masked. String/binary columns are cast to fixed-width
    numpy string dtypes, and fixed-size-list columns are stacked and
    reshaped to their (possibly multi-dimensional) per-row shape.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping of column name to column values.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = {} if schema.metadata is None else schema.metadata
    string_types = (pa.string(), pa.binary())

    def _column_values(column_name, arrow_type):
        column = arrow_table[column_name]
        if not column.null_count:
            # Regular non-masked column.
            return column.to_numpy()
        # Nulls must be replaced by a value of the right type before the
        # conversion; their positions are then restored as the mask.
        if arrow_type in string_types:
            placeholder = ""
        else:
            placeholder = arrow_type.to_pandas_dtype()(0)
        return np.ma.masked_array(
            data=column.fill_null(placeholder).to_numpy(),
            mask=column.is_null().to_numpy(),
        )

    converted = {}
    for column_name in schema.names:
        arrow_type = schema.field(column_name).type
        values = _column_values(column_name, arrow_type)

        if arrow_type in string_types:
            values = values.astype(_arrow_string_to_numpy_dtype(schema, column_name, values))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(values):
                values = np.stack(values)
            else:
                # Empty column: coerce to the list's element type.
                values = values.astype(arrow_type.value_type.to_pandas_dtype())

            per_row_shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, column_name)
            values = values.reshape((len(arrow_table), *per_row_shape))

        converted[column_name] = values

    return converted

297 return numpy_dict 

298 

299 

def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Pack a dict of numpy arrays into a single structured numpy array.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping of column name to column values.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured array with N rows and one field per dict key.
    """
    intermediate = numpy_dict_to_arrow(numpy_dict)
    structured = arrow_to_numpy(intermediate)
    return structured

314 

315 

def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Unpack a structured numpy array into a dict of per-field arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with one or more named fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping of field name to field values.
    """
    intermediate = numpy_to_arrow(np_array)
    columns = arrow_to_numpy_dict(intermediate)
    return columns

330 

331 

def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Build an arrow table from a structured numpy array.

    The schema metadata records the row count, the fixed width of string
    and bytes fields, and the shape of multi-dimensional fields.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with one or more named fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The resulting arrow table.
    """
    dtype = np_array.dtype
    arrow_types = _numpy_dtype_to_arrow_types(dtype)

    md = {b"lsst::arrow::rowcount": str(len(np_array))}
    for field_name in dtype.names:
        field_dtype = dtype[field_name]
        _append_numpy_string_metadata(md, field_name, field_dtype)
        _append_numpy_multidim_metadata(md, field_name, field_dtype)

    schema = pa.schema(arrow_types, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, len(np_array), np_array, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

366 

367 

def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping of column name to column values.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The resulting arrow table.

    Raises
    ------
    ValueError if columns in numpy_dict have unequal numbers of rows.
    """
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
    arrow_types = _numpy_dtype_to_arrow_types(dtype)

    md = {b"lsst::arrow::rowcount": str(rowcount)}
    # ``dtype.names`` may be None; treat that as "no fields".
    for field_name in dtype.names or ():
        field_dtype = dtype[field_name]
        _append_numpy_string_metadata(md, field_name, field_dtype)
        _append_numpy_multidim_metadata(md, field_name, field_dtype)

    schema = pa.schema(arrow_types, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, rowcount, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

408 

409 

def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Column attributes and table ``meta`` are serialized as astropy's YAML
    table header under the ``table_meta_yaml`` schema metadata key.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The resulting arrow table.
    """
    from astropy.table import meta

    dtype = astropy_table.dtype
    arrow_types = _numpy_dtype_to_arrow_types(dtype)

    md = {b"lsst::arrow::rowcount": str(len(astropy_table))}
    for field_name in dtype.names:
        field_dtype = dtype[field_name]
        _append_numpy_string_metadata(md, field_name, field_dtype)
        _append_numpy_multidim_metadata(md, field_name, field_dtype)

    header_lines = meta.get_yaml_from_table(astropy_table)
    md[b"table_meta_yaml"] = "\n".join(header_lines)

    schema = pa.schema(arrow_types, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, len(astropy_table), astropy_table, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

450 

451 

def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Convert an astropy table to a dict of numpy arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(astropy_to_arrow(astropy_table))

466 

467 

def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    The schema metadata records the row count and, for each string column,
    the length of its longest string.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        Default string length used when it cannot be inferred from the
        column (the column is empty or contains only nulls).

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update a copy of the metadata (guarding against a schema without any).
    md = dict(arrow_table.schema.metadata or {})

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        # Skip internal columns such as a serialized pandas index.
        if name.startswith("__"):
            continue
        column = arrow_table[name]
        if column.type != pa.string():
            continue
        # Longest non-null string; empty or all-null columns fall back to
        # ``default_length`` instead of making ``max`` raise ValueError.
        strlen = max(
            (len(row.as_py()) for row in column if row.is_valid),
            default=default_length,
        )
        md[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(strlen)

    return arrow_table.replace_schema_metadata(md)

505 

506 

def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Build an astropy table from a pandas dataframe, keeping its index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert. Multi-index columns are not supported.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The resulting astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe has multi-index columns.
    """
    import pandas as pd
    from astropy.table import Table

    has_multi_index_columns = isinstance(dataframe.columns, pd.MultiIndex)
    if has_multi_index_columns:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    # ``index=True`` asks astropy to include the dataframe index in the table.
    converted = Table.from_pandas(dataframe, index=True)
    return converted

527 

528 

def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Convert a pandas dataframe to a dict of numpy arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(pandas_to_arrow(dataframe))

543 

544 

def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Wrap a structured numpy array as an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured array with one or more named fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The resulting astropy table.
    """
    from astropy.table import Table

    # copy=False asks astropy not to copy the input data.
    wrapped = Table(data=np_array, copy=False)
    return wrapped

561 

562 

def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema. It may have no metadata at all.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only when
        the ``pandas`` metadata describes more than one column-index level.
    """
    import pandas as pd

    # Schemas not written via pandas may have no metadata (None).
    metadata = schema.metadata if schema.metadata is not None else {}
    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]
    else:
        indexes = []

    if len(indexes) <= 1:
        # Flat columns; skip internal columns such as a serialized index.
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])

590 

591 

def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Convert an arrow schema to a list of string column names.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        New list containing every column name in schema order, including
        internal names such as those starting with ``__``.
    """
    return list(schema.names)

606 

607 

class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row selection of the input dataframe,
    so it keeps the columns and their dtypes but no data.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # Select no rows: only the column structure and dtypes remain.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty table with the given schema converts to a zero-row frame.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        """The zero-row dataframe holding the schema."""
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)

682 

683 

class ArrowAstropySchema:
    """Schema wrapper for an astropy table.

    The schema is held as a zero-length slice of the table, which keeps the
    column definitions (dtype and column attributes) without any rows.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table whose schema is captured.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Build an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Arrow schema to convert.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The resulting schema wrapper.
        """
        import numpy as np
        from astropy.table import Table

        empty_rows = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        empty_table = Table(data=empty_rows)

        arrow_metadata = {} if schema.metadata is None else schema.metadata
        _apply_astropy_metadata(empty_table, arrow_metadata)

        return cls(empty_table)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The arrow schema.
        """
        return astropy_to_arrow(self._schema).schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The dataframe schema.
        """
        return DataFrameSchema.from_arrow(self.to_arrow_schema())

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Return the equivalent `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            The numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        mine, theirs = self._schema, other._schema
        # Equal dtypes imply identical column names in the same order.
        if mine.dtype != theirs.dtype:
            return False

        # Then every column must agree on unit, description and format.
        return all(
            getattr(mine[column_name], attr) == getattr(theirs[column_name], attr)
            for column_name in mine.columns
            for attr in ("unit", "description", "format")
        )

779 

780 

class ArrowNumpySchema:
    """Schema wrapper for a structured numpy array.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Structured dtype describing the array's fields.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Build an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Arrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            The resulting schema wrapper.
        """
        import numpy as np

        field_list = _schema_to_dtype_list(schema)
        return cls(np.dtype(field_list))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Return the equivalent `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            The astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            The dataframe schema.
        """
        return DataFrameSchema.from_arrow(self.to_arrow_schema())

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            The arrow schema.
        """
        import numpy as np

        # A zero-row array is enough to carry the dtype through the
        # numpy -> arrow conversion.
        empty_rows = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty_rows).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented
        return bool(self._dtype == other._dtype)

864 

865 

866def _split_multi_index_column_names(n: int, names: Iterable[str]) -> List[Sequence[str]]: 

867 """Split a string that represents a multi-index column. 

868 

869 PyArrow maps Pandas' multi-index column names (which are tuples in Python) 

870 to flat strings on disk. This routine exists to reconstruct the original 

871 tuple. 

872 

873 Parameters 

874 ---------- 

875 n : `int` 

876 Number of levels in the `pandas.MultiIndex` that is being 

877 reconstructed. 

878 names : `~collections.abc.Iterable` [`str`] 

879 Strings to be split. 

880 

881 Returns 

882 ------- 

883 column_names : `list` [`tuple` [`str`]] 

884 A list of multi-index column name tuples. 

885 """ 

886 column_names: List[Sequence[str]] = [] 

887 

888 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n))) 

889 for name in names: 

890 m = re.search(pattern, name) 

891 if m is not None: 

892 column_names.append(m.groups()) 

893 

894 return column_names 

895 

896 

def _standardize_multi_index_columns(
    schema: pa.Schema, columns: Union[List[tuple], dict[str, Union[str, List[str]]]]
) -> List[str]:
    """Transform a dictionary/iterable index from a multi-index column list
    into a string directly understandable by PyArrow.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Pyarrow schema.
    columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]]
        Columns to standardize. A list must contain full column tuples; a
        dict maps level names to the value(s) wanted for that level, and
        levels that are not mentioned select all of their values.

    Returns
    -------
    names : `list` [`str`]
        Stringified representation of a multi-index column name.

    Raises
    ------
    ValueError
        Raised if ``columns`` has the wrong type, names unknown levels, or
        requests a value that is not present in its level.
    """
    pd_index = arrow_schema_to_pandas_index(schema)
    index_level_names = tuple(pd_index.names)

    if isinstance(columns, list):
        names = []
        for requested in columns:
            if not isinstance(requested, tuple):
                raise ValueError(
                    "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
                    f"Instead got a {get_full_type_name(requested)}."
                )
            names.append(str(requested))
        return names

    if not isinstance(columns, collections.abc.Mapping):
        raise ValueError(
            "Columns parameter for multi-index data frame must be a dictionary or list of tuples. "
            f"Instead got a {get_full_type_name(columns)}."
        )
    if not set(index_level_names).issuperset(columns.keys()):
        raise ValueError(
            f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}."
        )

    # Materialize each level's selection so it can be both validated and
    # used to build the cartesian product.
    factors = [
        list(ensure_iterable(columns.get(level, pd_index.levels[i])))
        for i, level in enumerate(index_level_names)
    ]

    # Validate each requested value once per level rather than once per
    # combination in the product.
    for i, factor in enumerate(factors):
        level_values = pd_index.levels[i]
        for value in factor:
            if value not in level_values:
                raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.")

    return [str(requested) for requested in itertools.product(*factors)]

949 

950 

def _apply_astropy_metadata(astropy_table: atable.Table, metadata: dict) -> None:
    """Restore astropy column attributes and table meta from arrow metadata.

    Reads the YAML header stored under ``table_meta_yaml`` and, if present,
    copies each column's description, format, unit and meta onto the
    matching column of ``astropy_table``, then merges the table-level meta.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to update in place.
    metadata : `dict` [`bytes`]
        Arrow schema metadata dict.
    """
    from astropy.table import meta

    raw_yaml = metadata.get(b"table_meta_yaml", None)
    if not raw_yaml:
        # No serialized header: nothing to apply.
        return

    header = meta.get_header_from_yaml(raw_yaml.decode("UTF8").split("\n"))

    # Per-column attribute entries from the header, keyed by column name.
    column_headers = {entry["name"]: entry for entry in header["datatype"]}
    for column in astropy_table.columns.values():
        column_header = column_headers[column.name]
        for attr in ("description", "format", "unit", "meta"):
            if attr in column_header:
                setattr(column, attr, column_header[attr])

    if "meta" in header:
        astropy_table.meta.update(header["meta"])

978 

979 

def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Get the numpy dtype string associated with an arrow column.

    The string length is taken from the ``lsst::arrow::len::<name>``
    schema metadata if present, otherwise from the longest (unmasked) value
    in ``numpy_column``, otherwise ``default_length``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column to determine numpy string dtype. May be a
        `numpy.ma.MaskedArray`, in which case masked entries are ignored.
    default_length : `int`, optional
        Default string length when it is not in the metadata and cannot be
        inferred from the column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string (``U<n>`` for string columns, ``|S<n>``
        otherwise).
    """
    import numpy as np

    # Special-case for string and binary columns
    md_name = f"lsst::arrow::len::{name}"
    strlen = default_length
    metadata = schema.metadata if schema.metadata is not None else {}
    if (encoded := md_name.encode("UTF-8")) in metadata:
        # String/bytes length from header.
        strlen = int(metadata[encoded])
    elif numpy_column is not None:
        if isinstance(numpy_column, np.ma.MaskedArray):
            # Masked entries iterate as ``np.ma.masked``, which has no len();
            # measure only the unmasked values.
            values = numpy_column.compressed()
        else:
            values = numpy_column
        if len(values) > 0:
            strlen = max(len(row) for row in values)

    dtype = f"U{strlen}" if schema.field(name).type == pa.string() else f"|S{strlen}"

    return dtype

1016 

1017 

1018def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1019 """Append numpy string length keys to arrow metadata. 

1020 

1021 All column types are handled, but the metadata is only modified for 

1022 string and byte columns. 

1023 

1024 Parameters 

1025 ---------- 

1026 metadata : `dict` [`bytes`, `str`] 

1027 Metadata dictionary; modified in place. 

1028 name : `str` 

1029 Column name. 

1030 dtype : `np.dtype` 

1031 Numpy dtype. 

1032 """ 

1033 import numpy as np 

1034 

1035 if dtype.type is np.str_: 

1036 metadata[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(dtype.itemsize // 4) 

1037 elif dtype.type is np.bytes_: 

1038 metadata[f"lsst::arrow::len::{name}".encode("UTF-8")] = str(dtype.itemsize) 

1039 

1040 

1041def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1042 """Append numpy multi-dimensional shapes to arrow metadata. 

1043 

1044 All column types are handled, but the metadata is only modified for 

1045 multi-dimensional columns. 

1046 

1047 Parameters 

1048 ---------- 

1049 metadata : `dict` [`bytes`, `str`] 

1050 Metadata dictionary; modified in place. 

1051 name : `str` 

1052 Column name. 

1053 dtype : `np.dtype` 

1054 Numpy dtype. 

1055 """ 

1056 if len(dtype.shape) > 1: 

1057 metadata[f"lsst::arrow::shape::{name}".encode("UTF-8")] = str(dtype.shape) 

1058 

1059 

1060def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]: 

1061 """Retrieve the shape from the metadata, if available. 

1062 

1063 Parameters 

1064 ---------- 

1065 metadata : `dict` [`bytes`, `bytes`] 

1066 Metadata dictionary. 

1067 list_size : `int` 

1068 Size of the list datatype. 

1069 name : `str` 

1070 Column name. 

1071 

1072 Returns 

1073 ------- 

1074 shape : `tuple` [`int`] 

1075 Shape associated with the column. 

1076 

1077 Raises 

1078 ------ 

1079 RuntimeError 

1080 Raised if metadata is found but has incorrect format. 

1081 """ 

1082 md_name = f"lsst::arrow::shape::{name}" 

1083 if (encoded := md_name.encode("UTF-8")) in metadata: 

1084 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8")) 

1085 if groups is None: 

1086 raise RuntimeError("Illegal value found in metadata.") 

1087 shape = tuple(int(x) for x in groups[1].split(",") if x != "") 

1088 else: 

1089 shape = (list_size,) 

1090 

1091 return shape 

1092 

1093 

def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Build a numpy-style dtype description from a pyarrow schema.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        One ``(name, type)`` pair per schema column, in schema order,
        suitable for passing to ``numpy.dtype``.
    """
    metadata = schema.metadata or {}

    def _column_dtype(column_name: str) -> Any:
        arrow_type = schema.field(column_name).type
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists map to sub-array dtypes; the full shape is
            # taken from the schema metadata when present.
            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, column_name)
            return (arrow_type.value_type.to_pandas_dtype(), shape)
        if arrow_type in (pa.string(), pa.binary()):
            # Strings/bytes need an explicit fixed length for numpy.
            return _arrow_string_to_numpy_dtype(schema, column_name)
        return arrow_type.to_pandas_dtype()

    return [(column_name, _column_dtype(column_name)) for column_name in schema.names]

1121 

1122 

1123def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]: 

1124 """Convert a numpy dtype to a list of arrow types. 

1125 

1126 Parameters 

1127 ---------- 

1128 dtype : `numpy.dtype` 

1129 Numpy dtype to convert. 

1130 

1131 Returns 

1132 ------- 

1133 type_list : `list` [`object`] 

1134 Converted list of arrow types. 

1135 """ 

1136 from math import prod 

1137 

1138 import numpy as np 

1139 

1140 type_list: list[Any] = [] 

1141 if dtype.names is None: 

1142 return type_list 

1143 

1144 for name in dtype.names: 

1145 dt = dtype[name] 

1146 arrow_type: Any 

1147 if len(dt.shape) > 0: 

1148 arrow_type = pa.list_( 

1149 pa.from_numpy_dtype(cast(tuple[np.dtype, tuple[int, ...]], dt.subdtype)[0].type), 

1150 prod(dt.shape), 

1151 ) 

1152 else: 

1153 arrow_type = pa.from_numpy_dtype(dt.type) 

1154 type_list.append((name, arrow_type)) 

1155 

1156 return type_list 

1157 

1158 

1159def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]: 

1160 """Extract equivalent table dtype from dict of numpy arrays. 

1161 

1162 Parameters 

1163 ---------- 

1164 numpy_dict : `dict` [`str`, `numpy.ndarray`] 

1165 Dict with keys as the column names, values as the arrays. 

1166 

1167 Returns 

1168 ------- 

1169 dtype : `numpy.dtype` 

1170 dtype of equivalent table. 

1171 rowcount : `int` 

1172 Number of rows in the table. 

1173 

1174 Raises 

1175 ------ 

1176 ValueError if columns in numpy_dict have unequal numbers of rows. 

1177 """ 

1178 import numpy as np 

1179 

1180 dtype_list = [] 

1181 rowcount = 0 

1182 for name, col in numpy_dict.items(): 

1183 if rowcount == 0: 

1184 rowcount = len(col) 

1185 if len(col) != rowcount: 

1186 raise ValueError(f"Column {name} has a different number of rows.") 

1187 if len(col.shape) == 1: 

1188 dtype_list.append((name, col.dtype)) 

1189 else: 

1190 dtype_list.append((name, (col.dtype, col.shape[1:]))) 

1191 dtype = np.dtype(dtype_list) 

1192 

1193 return (dtype, rowcount) 

1194 

1195 

def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
                      or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays.

    Notes
    -----
    If `pyarrow.array` raises `pyarrow.ArrowNotImplementedError` for a
    big-endian column, the column is byte-swapped to little-endian and the
    conversion is retried. The swap is applied to the whole column before
    multi-dimensional columns are split into per-row arrays. This lets the
    fallback work for both scalar and multi-dimensional columns.
    """
    import numpy as np

    def _is_big_endian(column_dtype: np.dtype) -> bool:
        # numpy reports native byte order as "=", so on a big-endian
        # machine native data also counts as big-endian.
        return (np.little_endian and column_dtype.byteorder == ">") or (
            not np.little_endian and column_dtype.byteorder == "="
        )

    def _to_arrow_values(column: Any, multidim: bool) -> Any:
        # Multi-dimensional columns are passed to arrow as one flattened
        # array per row.
        if not multidim:
            return column
        if rowcount > 0:
            return np.split(column.ravel(), rowcount)
        return []

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    for name in dtype.names:
        multidim = len(dtype[name].shape) > 0
        arrow_type = schema.field(name).type
        column = np_style_arrays[name]
        try:
            arrow_arrays.append(pa.array(_to_arrow_values(column, multidim), type=arrow_type))
        except pa.ArrowNotImplementedError:
            if not _is_big_endian(column.dtype):
                # This failed for some other reason so re-raise.
                raise
            # Convert the column to little-endian: swap the bytes, then
            # relabel the dtype so the values themselves are unchanged.
            swapped = column.byteswap()
            swapped.dtype = swapped.dtype.newbyteorder("<")
            arrow_arrays.append(pa.array(_to_arrow_values(swapped, multidim), type=arrow_type))

    return arrow_arrays