Coverage for python/lsst/daf/butler/formatters/parquet.py: 14%

442 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-05 01:26 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ( 

25 "ParquetFormatter", 

26 "arrow_to_pandas", 

27 "arrow_to_astropy", 

28 "arrow_to_numpy", 

29 "arrow_to_numpy_dict", 

30 "pandas_to_arrow", 

31 "pandas_to_astropy", 

32 "astropy_to_arrow", 

33 "numpy_to_arrow", 

34 "numpy_to_astropy", 

35 "numpy_dict_to_arrow", 

36 "arrow_schema_to_pandas_index", 

37 "DataFrameSchema", 

38 "ArrowAstropySchema", 

39 "ArrowNumpySchema", 

40 "compute_row_group_size", 

41) 

42 

43import collections.abc 

44import itertools 

45import json 

46import re 

47from collections.abc import Iterable, Sequence 

48from typing import TYPE_CHECKING, Any, cast 

49 

50import pyarrow as pa 

51import pyarrow.parquet as pq 

52from lsst.daf.butler import Formatter 

53from lsst.utils.introspection import get_full_type_name 

54from lsst.utils.iteration import ensure_iterable 

55 

56if TYPE_CHECKING: 

57 import astropy.table as atable 

58 import numpy as np 

59 import pandas as pd 

60 

61TARGET_ROW_GROUP_BYTES = 1_000_000_000 

62 

63 

class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)

        # ``schema.metadata`` may be `None`; normalize it once so that every
        # membership test below is safe, not only the final table read.
        metadata = schema.metadata if schema.metadata is not None else {}

        if component in ("columns", "schema"):
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise count.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            # Fall back to reading a single column and counting its rows.
            temp_table = pq.read_table(
                path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return len(temp_table[schema.names[0]])

        par_columns = None
        if self.fileDescriptor.parameters:
            # NOTE: ``pop`` removes "columns" from the parameters mapping so
            # that anything left over can be reported as unsupported below.
            par_columns = self.fileDescriptor.parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            if len(self.fileDescriptor.parameters):
                raise ValueError(
                    f"Unsupported parameters {self.fileDescriptor.parameters} in ArrowTable read."
                )

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Convert the supported in-memory types to an arrow table, then
        # write that table out as a parquet file.
        import numpy as np
        from astropy.table import Table as astropyTable

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        else:
            if hasattr(inMemoryDataset, "to_parquet"):
                # This may be a pandas DataFrame; pandas is only imported
                # when the object looks like one.
                try:
                    import pandas as pd
                except ImportError:
                    pd = None

                if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                    arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)

175 

176 

def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Build a pandas DataFrame from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Any ``pandas`` metadata stored in its schema is
        used when constructing the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    dataframe = arrow_table.to_pandas(
        use_threads=False,
        integer_object_nulls=True,
    )
    return dataframe

193 

194 

def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Build an `astropy.Table` from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Astropy metadata serialized in the schema
        metadata, if any, is applied to the resulting table.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    result = Table(arrow_to_numpy_dict(arrow_table))

    schema_metadata = arrow_table.schema.metadata
    if schema_metadata is None:
        schema_metadata = {}
    _apply_astropy_metadata(result, schema_metadata)

    return result

219 

220 

def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Build a structured numpy array from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with N rows whose field names are the column names
        of the input table.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    # One-dimensional columns map to a plain field; columns with extra
    # dimensions become sub-array fields with the per-row shape.
    dtype = [
        (name, col.dtype) if len(col.shape) <= 1 else (name, (col.dtype, col.shape[1:]))
        for name, col in columns.items()
    ]

    return np.rec.fromarrays(columns.values(), dtype=dtype)

249 

250 

def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pyarrow table.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to the column data. Columns containing
        nulls are returned as numpy masked arrays.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = {} if schema.metadata is None else schema.metadata
    text_types = (pa.string(), pa.binary())

    result = {}

    for name in schema.names:
        arrow_type = schema.field(name).type
        column = arrow_table[name]
        is_text = arrow_type in text_types

        if column.null_count != 0:
            # Nulls present: have arrow substitute a typed placeholder
            # before conversion, then mask those slots so the result is a
            # masked array of the right dtype.
            placeholder = "" if is_text else arrow_type.to_pandas_dtype()(0)
            col = np.ma.masked_array(
                data=column.fill_null(placeholder).to_numpy(),
                mask=column.is_null().to_numpy(),
            )
        else:
            # No nulls: plain conversion.
            col = column.to_numpy()

        if is_text:
            col = col.astype(_arrow_string_to_numpy_dtype(schema, name, col))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(col) > 0:
                col = np.stack(col)
            else:
                # An empty column has to be coerced to the value type.
                col = col.astype(arrow_type.value_type.to_pandas_dtype())

            per_row_shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            col = col.reshape((len(arrow_table), *per_row_shape))

        result[name] = col

    return result

307 

308 

def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Build a structured numpy array from a dict of numpy arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured array with N rows and field names taken from the keys.
    """
    intermediate = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(intermediate)

323 

324 

def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a structured numpy array.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array with multiple fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.
    """
    intermediate = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(intermediate)

339 

340 

def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Build an arrow table from a structured numpy array.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array with multiple fields.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. Its schema metadata records the row
        count plus string-length and multi-dimensional-shape entries for
        the fields that need them.
    """
    np_dtype = np_array.dtype
    n_rows = len(np_array)

    type_list = _numpy_dtype_to_arrow_types(np_dtype)

    md = {b"lsst::arrow::rowcount": str(n_rows)}
    for name in np_dtype.names:
        field_dtype = np_dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    schema = pa.schema(type_list, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(np_dtype, n_rows, np_array, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

375 

376 

def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Build an arrow table from a dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table.

    Raises
    ------
    ValueError if columns in numpy_dict have unequal numbers of rows.
    """
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
    type_list = _numpy_dtype_to_arrow_types(dtype)

    md = {b"lsst::arrow::rowcount": str(rowcount)}

    # ``dtype.names`` is `None` for a dtype without fields.
    field_names = dtype.names if dtype.names is not None else ()
    for name in field_names:
        field_dtype = dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    schema = pa.schema(type_list, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, rowcount, numpy_dict, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

417 

418 

def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Build an arrow table from an astropy table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        The converted arrow table. The astropy table header is serialized
        as YAML under the ``table_meta_yaml`` schema metadata key.
    """
    from astropy.table import meta

    table_dtype = astropy_table.dtype
    n_rows = len(astropy_table)

    type_list = _numpy_dtype_to_arrow_types(table_dtype)

    md = {b"lsst::arrow::rowcount": str(n_rows)}
    for name in table_dtype.names:
        field_dtype = table_dtype[name]
        _append_numpy_string_metadata(md, name, field_dtype)
        _append_numpy_multidim_metadata(md, name, field_dtype)

    # Serialize the astropy header as one newline-joined YAML string.
    md[b"table_meta_yaml"] = "\n".join(meta.get_yaml_from_table(astropy_table))

    schema = pa.schema(type_list, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(table_dtype, n_rows, astropy_table, schema)

    return pa.Table.from_arrays(arrays, schema=schema)

459 

460 

def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from an astropy table.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Table to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.
    """
    intermediate = astropy_to_arrow(astropy_table)
    return arrow_to_numpy_dict(intermediate)

475 

476 

def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column that has no valid
        (non-null) rows to measure, i.e. one that is empty or all null.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table, with the row count and per-column string
        lengths added to the schema metadata.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata
    md = arrow_table.schema.metadata

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if not name.startswith("__") and arrow_table[name].type == pa.string():
            # ``default`` covers both an empty column and a non-empty column
            # whose rows are all null; without it ``max`` raises ValueError
            # on the empty generator in the all-null case.
            strlen = max(
                (len(row.as_py()) for row in arrow_table[name] if row.is_valid),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table

513 

514 

def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Build an astropy table from a pandas dataframe, keeping the index.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        The converted astropy table.

    Raises
    ------
    ValueError
        Raised if the dataframe columns are a `pandas.MultiIndex`.
    """
    import pandas as pd
    from astropy.table import Table

    has_multi_index_columns = isinstance(dataframe.columns, pd.MultiIndex)
    if has_multi_index_columns:
        raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

    astropy_table = Table.from_pandas(dataframe, index=True)
    return astropy_table

535 

536 

def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Build a dict of numpy arrays from a pandas dataframe.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to convert.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column data.
    """
    intermediate = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(intermediate)

551 

552 

def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Wrap a structured numpy array in an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array with multiple fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Astropy table constructed with ``copy=False``.
    """
    from astropy.table import Table

    astropy_table = Table(data=np_array, copy=False)
    return astropy_table

569 

570 

def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only when
        the ``pandas`` schema metadata lists more than one column index
        level; otherwise a flat index of the column names that do not
        start with ``__`` is returned.
    """
    import pandas as pd

    # ``schema.metadata`` may be `None`; treat that the same as a schema
    # without pandas metadata instead of failing on the ``in`` test.
    metadata = schema.metadata if schema.metadata is not None else {}

    if b"pandas" in metadata:
        indexes = json.loads(metadata[b"pandas"])["column_indexes"]
    else:
        indexes = []

    if len(indexes) <= 1:
        return pd.Index(name for name in schema.names if not name.startswith("__"))

    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])

598 

599 

def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Return the column names of an arrow schema as a list of strings.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        New list holding every name in ``schema.names``, in order.
    """
    return list(schema.names)

614 

615 

class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row copy of the dataframe, which keeps
    the column labels and dtypes but none of the data.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean mask selects no rows.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # An empty row list with the given schema yields a zero-row table.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        # The zero-row dataframe built in ``__init__``.
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)

690 

691 

class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    The schema is stored as a zero-row slice of the table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Build an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        # Zero-length structured array carrying the schema's dtype.
        empty = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        table = Table(data=empty)

        schema_metadata = schema.metadata
        if schema_metadata is None:
            schema_metadata = {}
        _apply_astropy_metadata(table, schema_metadata)

        return cls(table)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = astropy_to_arrow(self._schema)
        return arrow_table.schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a DataFrameSchema.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return DataFrameSchema.from_arrow(arrow_schema)

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        arrow_schema = astropy_to_arrow(self._schema).schema
        return ArrowNumpySchema.from_arrow(arrow_schema)

    @property
    def schema(self) -> atable.Table:
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # Equal dtypes imply that both tables have the same column names.
        if self._schema.dtype != other._schema.dtype:
            return False

        # Beyond the dtype, every column must agree on these attributes.
        for name in self._schema.columns:
            for attr in ("unit", "description", "format"):
                mine = getattr(self._schema[name], attr)
                theirs = getattr(other._schema[name], attr)
                if not mine == theirs:
                    return False

        return True

787 

788 

class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Numpy dtype to convert.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Build an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        dtype_list = _schema_to_dtype_list(schema)
        return cls(np.dtype(dtype_list))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return ArrowAstropySchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Convert to a `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        import numpy as np

        empty = np.zeros(0, dtype=self._dtype)
        return DataFrameSchema.from_arrow(numpy_to_arrow(empty).schema)

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to a `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        # Convert a zero-length array of this dtype and keep only the schema.
        empty = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty).schema

    @property
    def schema(self) -> np.dtype:
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        return True if self._dtype == other._dtype else False

872 

873 

874def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]: 

875 """Split a string that represents a multi-index column. 

876 

877 PyArrow maps Pandas' multi-index column names (which are tuples in Python) 

878 to flat strings on disk. This routine exists to reconstruct the original 

879 tuple. 

880 

881 Parameters 

882 ---------- 

883 n : `int` 

884 Number of levels in the `pandas.MultiIndex` that is being 

885 reconstructed. 

886 names : `~collections.abc.Iterable` [`str`] 

887 Strings to be split. 

888 

889 Returns 

890 ------- 

891 column_names : `list` [`tuple` [`str`]] 

892 A list of multi-index column name tuples. 

893 """ 

894 column_names: list[Sequence[str]] = [] 

895 

896 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n))) 

897 for name in names: 

898 m = re.search(pattern, name) 

899 if m is not None: 

900 column_names.append(m.groups()) 

901 

902 return column_names 

903 

904 

905def _standardize_multi_index_columns( 

906 pd_index: pd.MultiIndex, 

907 columns: Any, 

908 stringify: bool = True, 

909) -> list[str | Sequence[Any]]: 

910 """Transform a dictionary/iterable index from a multi-index column list 

911 into a string directly understandable by PyArrow. 

912 

913 Parameters 

914 ---------- 

915 pd_index : `pandas.MultiIndex` 

916 Pandas multi-index. 

917 columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]] 

918 Columns to standardize. 

919 stringify : `bool`, optional 

920 Should the column names be stringified? 

921 

922 Returns 

923 ------- 

924 names : `list` [`str`] 

925 Stringified representation of a multi-index column name. 

926 """ 

927 index_level_names = tuple(pd_index.names) 

928 

929 names: list[str | Sequence[Any]] = [] 

930 

931 if isinstance(columns, list): 

932 for requested in columns: 

933 if not isinstance(requested, tuple): 

934 raise ValueError( 

935 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. " 

936 f"Instead got a {get_full_type_name(requested)}." 

937 ) 

938 if stringify: 

939 names.append(str(requested)) 

940 else: 

941 names.append(requested) 

942 else: 

943 if not isinstance(columns, collections.abc.Mapping): 

944 raise ValueError( 

945 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. " 

946 f"Instead got a {get_full_type_name(columns)}." 

947 ) 

948 if not set(index_level_names).issuperset(columns.keys()): 

949 raise ValueError( 

950 f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}." 

951 ) 

952 factors = [ 

953 ensure_iterable(columns.get(level, pd_index.levels[i])) 

954 for i, level in enumerate(index_level_names) 

955 ] 

956 for requested in itertools.product(*factors): 

957 for i, value in enumerate(requested): 

958 if value not in pd_index.levels[i]: 

959 raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.") 

960 if stringify: 

961 names.append(str(requested)) 

962 else: 

963 names.append(requested) 

964 

965 return names 

966 

967 

def _apply_astropy_metadata(astropy_table: atable.Table, metadata: dict) -> None:
    """Apply astropy metadata stored in schema metadata to a table.

    Nothing is done when the ``table_meta_yaml`` key is missing or empty.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to apply metadata; modified in place.
    metadata : `dict` [`bytes`]
        Metadata dict.
    """
    from astropy.table import meta

    serialized = metadata.get(b"table_meta_yaml", None)
    if not serialized:
        return

    header = meta.get_header_from_yaml(serialized.decode("UTF8").split("\n"))

    # Restore the per-column description, format, unit and meta that were
    # serialized along with the table.
    columns_by_name = {entry["name"]: entry for entry in header["datatype"]}
    for col in astropy_table.columns.values():
        col_header = columns_by_name[col.name]
        for attr in ("description", "format", "unit", "meta"):
            if attr in col_header:
                setattr(col, attr, col_header[attr])

    if "meta" in header:
        astropy_table.meta.update(header["meta"])

995 

996 

def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Return the numpy dtype string for an arrow string or binary column.

    The length is taken from the ``lsst::arrow::len::<name>`` schema
    metadata entry when present, else from the longest element of
    ``numpy_column`` when that is given and non-empty, else from
    ``default_length``.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column to determine numpy string dtype.
    default_length : `int`, optional
        Default string length when not in metadata or can be inferred
        from column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string: ``U<len>`` for an arrow string column,
        ``|S<len>`` otherwise.
    """
    metadata = {} if schema.metadata is None else schema.metadata
    key = f"lsst::arrow::len::{name}".encode("UTF-8")

    if key in metadata:
        # String/bytes length from header.
        strlen = int(metadata[key])
    elif numpy_column is not None and len(numpy_column) > 0:
        strlen = max(len(row) for row in numpy_column)
    else:
        strlen = default_length

    kind = "U" if schema.field(name).type == pa.string() else "|S"
    return f"{kind}{strlen}"

1032 

1033 

def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Record the string length of a column in the arrow metadata.

    Any dtype may be passed in; only unicode-string and byte-string
    dtypes cause an entry to be written.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype.
    """
    import numpy as np

    scalar_type = dtype.type
    if scalar_type is np.str_:
        # Unicode itemsize is in bytes; dividing by 4 gives the number of
        # characters.
        length = dtype.itemsize // 4
    elif scalar_type is np.bytes_:
        length = dtype.itemsize
    else:
        # Not a string-like column; nothing to record.
        return

    metadata[f"lsst::arrow::len::{name}".encode()] = str(length)

1055 

1056 

def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None:
    """Record the shape of a multi-dimensional column in arrow metadata.

    Any dtype may be passed in; an entry is written only when the dtype
    shape has more than one dimension.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `str`]
        Metadata dictionary; modified in place.
    name : `str`
        Column name.
    dtype : `np.dtype`
        Numpy dtype.
    """
    column_shape = dtype.shape
    if len(column_shape) <= 1:
        # Scalar or one-dimensional per-row values need no shape entry.
        return

    key = f"lsst::arrow::shape::{name}".encode()
    metadata[key] = str(column_shape)

1074 

1075 

def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]:
    """Retrieve the shape from the metadata, if available.

    The shape is looked up under the ``lsst::arrow::shape::<name>`` key
    and is expected to be a parenthesized, comma-separated list of
    integers, e.g. ``(2, 3)``. When the key is absent the column is
    treated as one-dimensional with ``list_size`` elements.

    Parameters
    ----------
    metadata : `dict` [`bytes`, `bytes`]
        Metadata dictionary.
    list_size : `int`
        Size of the list datatype.
    name : `str`
        Column name.

    Returns
    -------
    shape : `tuple` [`int`]
        Shape associated with the column.

    Raises
    ------
    RuntimeError
        Raised if metadata is found but has incorrect format.
    """
    md_name = f"lsst::arrow::shape::{name}"
    if (encoded := md_name.encode("UTF-8")) not in metadata:
        return (list_size,)

    groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8"))
    if groups is None:
        raise RuntimeError("Illegal value found in metadata.")

    try:
        # Blank elements (e.g. after the trailing comma of "(5,)") are
        # ignored; int() itself tolerates surrounding whitespace.
        shape = tuple(int(x) for x in groups[1].split(",") if x.strip() != "")
    except ValueError as err:
        # Non-integer content is a malformed value; report it with the
        # documented exception type rather than a bare ValueError.
        raise RuntimeError("Illegal value found in metadata.") from err

    return shape

1108 

1109 

def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Build a numpy dtype description from a pyarrow schema.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        A list with name, type pairs, one per schema column and in
        schema order.
    """
    schema_metadata = schema.metadata if schema.metadata is not None else {}
    string_like_types = (pa.string(), pa.binary())

    entries: list[Any] = []
    for column in schema.names:
        arrow_type = schema.field(column).type
        column_dtype: Any
        if isinstance(arrow_type, pa.FixedSizeListType):
            # Fixed-size lists become sub-array dtypes: (element type, shape).
            shape = _multidim_shape_from_metadata(schema_metadata, arrow_type.list_size, column)
            column_dtype = (arrow_type.value_type.to_pandas_dtype(), shape)
        elif arrow_type in string_like_types:
            # String/binary columns need an explicit length in the dtype.
            column_dtype = _arrow_string_to_numpy_dtype(schema, column)
        else:
            column_dtype = arrow_type.to_pandas_dtype()
        entries.append((column, column_dtype))

    return entries

1137 

1138 

def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Translate the fields of a numpy dtype into arrow types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`object`]
        Converted list of ``(name, arrow type)`` pairs; empty when
        ``dtype`` has no named fields.
    """
    from math import prod

    import numpy as np

    if dtype.names is None:
        # No named fields, so there is nothing to convert.
        return []

    converted: list[Any] = []
    for field_name in dtype.names:
        field_dtype = dtype[field_name]
        arrow_type: Any
        if len(field_dtype.shape) == 0:
            arrow_type = pa.from_numpy_dtype(field_dtype.type)
        else:
            # Sub-array field: store as a flat fixed-size list holding
            # every element of the per-row array.
            element_dtype = cast(tuple[np.dtype, tuple[int, ...]], field_dtype.subdtype)[0]
            arrow_type = pa.list_(pa.from_numpy_dtype(element_dtype.type), prod(field_dtype.shape))
        converted.append((field_name, arrow_type))

    return converted

1173 

1174 

def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]:
    """Extract equivalent table dtype from dict of numpy arrays.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.

    Returns
    -------
    dtype : `numpy.dtype`
        dtype of equivalent table.
    rowcount : `int`
        Number of rows in the table; 0 when ``numpy_dict`` is empty.

    Raises
    ------
    ValueError
        Raised if columns in numpy_dict have unequal numbers of rows.
    """
    import numpy as np

    dtype_list = []
    # None means "no column seen yet". Using 0 as the sentinel would let an
    # empty first column be silently overridden by a later, longer column.
    rowcount: int | None = None
    for name, col in numpy_dict.items():
        nrows = len(col)
        if rowcount is None:
            rowcount = nrows
        elif nrows != rowcount:
            raise ValueError(f"Column {name} has a different number of rows.")
        if len(col.shape) == 1:
            dtype_list.append((name, col.dtype))
        else:
            # Trailing dimensions become the per-row sub-array shape.
            dtype_list.append((name, (col.dtype, col.shape[1:])))
    dtype = np.dtype(dtype_list)

    return (dtype, 0 if rowcount is None else rowcount)

1210 

1211 

def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
        or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays, one per field of ``dtype``;
        empty when ``dtype`` has no named fields.

    Raises
    ------
    pyarrow.ArrowNotImplementedError
        Raised if a column cannot be converted and the failure is not
        attributable to a non-little-endian byte order.
    """
    import numpy as np

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    for name in dtype.names:
        dt = dtype[name]
        val: Any
        if len(dt.shape) > 0:
            if rowcount > 0:
                # Flatten the column and cut it into one flat array per row.
                val = np.split(np_style_arrays[name].ravel(), rowcount)
            else:
                val = []
        else:
            val = np_style_arrays[name]

        try:
            arrow_arrays.append(pa.array(val, type=schema.field(name).type))
        except pa.ArrowNotImplementedError:
            # ``val`` is a plain list on the multi-dimensional path and so
            # has no ``dtype``; look the byte order up defensively so the
            # original arrow error is not masked by an AttributeError.
            byteorder = getattr(getattr(val, "dtype", None), "byteorder", None)
            # Check if val is big-endian (explicitly, or natively on a
            # big-endian platform).
            if (np.little_endian and byteorder == ">") or (not np.little_endian and byteorder == "="):
                # We need to convert the array to little-endian.
                val2 = val.byteswap()
                val2.dtype = val2.dtype.newbyteorder("<")
                arrow_arrays.append(pa.array(val2, type=schema.field(name).type))
            else:
                # This failed for some other reason so raise the exception.
                raise

    return arrow_arrays

1270 

1271 

def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row group
    that targets the persisted size on disk (or smaller).  The exact size on
    disk depends on the compression settings and ratios; typical binary data
    tables will have around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size. Always at
        least 1, even when a single row is estimated to exceed
        ``target_size``.
    """
    bit_width = 0

    metadata = schema.metadata if schema.metadata is not None else {}

    for name in schema.names:
        t = schema.field(name).type

        if t in (pa.string(), pa.binary()):
            md_name = f"lsst::arrow::len::{name}"

            if (encoded := md_name.encode("UTF-8")) in metadata:
                # String/bytes length from header.
                strlen = int(metadata[encoded])
            else:
                # We don't know the string width, so guess something.
                strlen = 10

            # Assuming UTF-8 encoding, and very few wide characters.
            t_width = 8 * strlen
        elif isinstance(t, pa.FixedSizeListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                t_width = t.list_size * t.value_type.bit_width
        elif t == pa.null():
            t_width = 0
        elif isinstance(t, pa.ListType):
            if t.value_type == pa.null():
                t_width = 0
            else:
                # This is a variable length list, just choose
                # something arbitrary.
                t_width = 10 * t.value_type.bit_width
        else:
            t_width = t.bit_width

        bit_width += t_width

    # Insist it is at least 1 byte wide to avoid any divide-by-zero errors.
    if bit_width < 8:
        bit_width = 8

    byte_width = bit_width // 8

    # A row group must hold at least one row; without the clamp a target
    # smaller than one row's estimated width would yield 0.
    return max(1, target_size // byte_width)