Coverage for python/lsst/daf/butler/formatters/parquet.py: 12%

444 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-28 10:10 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

__all__ = (
    # Butler formatter.
    "ParquetFormatter",
    # Conversions out of arrow tables.
    "arrow_to_pandas",
    "arrow_to_astropy",
    "arrow_to_numpy",
    "arrow_to_numpy_dict",
    # Conversions into arrow tables and between in-memory table types.
    "pandas_to_arrow",
    "pandas_to_astropy",
    "astropy_to_arrow",
    "numpy_to_arrow",
    "numpy_to_astropy",
    "numpy_dict_to_arrow",
    # Schema wrappers and helpers.
    "arrow_schema_to_pandas_index",
    "DataFrameSchema",
    "ArrowAstropySchema",
    "ArrowNumpySchema",
    "compute_row_group_size",
)

42 

43import collections.abc 

44import itertools 

45import json 

46import re 

47from collections.abc import Iterable, Sequence 

48from typing import TYPE_CHECKING, Any, cast 

49 

50import pyarrow as pa 

51import pyarrow.parquet as pq 

52from lsst.daf.butler import Formatter 

53from lsst.utils.introspection import get_full_type_name 

54from lsst.utils.iteration import ensure_iterable 

55 

56if TYPE_CHECKING: 

57 import astropy.table as atable 

58 import numpy as np 

59 import pandas as pd 

60 

# Target size, in bytes, of a parquet row group (one gigabyte); presumably
# consumed by ``compute_row_group_size`` (defined elsewhere) — confirm.
TARGET_ROW_GROUP_BYTES = 10**9

62 

63 

class ParquetFormatter(Formatter):
    """Interface for reading and writing Arrow Table objects to and from
    Parquet files.
    """

    extension = ".parq"

    def read(self, component: str | None = None) -> Any:
        # Docstring inherited from Formatter.read.
        path = self.fileDescriptor.location.path
        schema = pq.read_schema(path)
        # ``schema.metadata`` is None for parquet files without schema
        # metadata; normalize to a dict so the membership tests below
        # cannot fail with a TypeError.
        metadata = schema.metadata if schema.metadata is not None else {}

        if component in ("columns", "schema"):
            # The schema will be translated to column format
            # depending on the input type.
            return schema
        elif component == "rowcount":
            # Get the rowcount from the metadata if possible, otherwise
            # count it by reading a single column.
            if b"lsst::arrow::rowcount" in metadata:
                return int(metadata[b"lsst::arrow::rowcount"])

            temp_table = pq.read_table(
                path,
                columns=[schema.names[0]],
                use_threads=False,
                use_pandas_metadata=False,
            )

            return temp_table.num_rows

        par_columns = None
        if self.fileDescriptor.parameters:
            # Pop recognized parameters from a copy so that reading does not
            # mutate the file descriptor's parameters.
            parameters = dict(self.fileDescriptor.parameters)
            par_columns = parameters.pop("columns", None)
            if par_columns:
                has_pandas_multi_index = False
                if b"pandas" in metadata:
                    md = json.loads(metadata[b"pandas"])
                    if len(md["column_indexes"]) > 1:
                        has_pandas_multi_index = True

                if not has_pandas_multi_index:
                    # Ensure uniqueness, keeping order.
                    par_columns = list(dict.fromkeys(ensure_iterable(par_columns)))
                    # Names starting with "__" are excluded from the
                    # selectable columns (presumably serialized pandas
                    # index columns).
                    file_columns = [name for name in schema.names if not name.startswith("__")]

                    for par_column in par_columns:
                        if par_column not in file_columns:
                            raise ValueError(
                                f"Column {par_column} specified in parameters not available in parquet file."
                            )
                else:
                    par_columns = _standardize_multi_index_columns(
                        arrow_schema_to_pandas_index(schema),
                        par_columns,
                    )

            # Anything left over was not understood by this formatter.
            if len(parameters):
                raise ValueError(f"Unsupported parameters {parameters} in ArrowTable read.")

        arrow_table = pq.read_table(
            path,
            columns=par_columns,
            use_threads=False,
            use_pandas_metadata=(b"pandas" in metadata),
        )

        return arrow_table

    def write(self, inMemoryDataset: Any) -> None:
        # Docstring inherited from Formatter.write.
        import numpy as np
        from astropy.table import Table as astropyTable

        arrow_table = None
        if isinstance(inMemoryDataset, pa.Table):
            # This will be the most likely match.
            arrow_table = inMemoryDataset
        elif isinstance(inMemoryDataset, astropyTable):
            arrow_table = astropy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, np.ndarray):
            arrow_table = numpy_to_arrow(inMemoryDataset)
        elif isinstance(inMemoryDataset, dict):
            try:
                arrow_table = numpy_dict_to_arrow(inMemoryDataset)
            except (TypeError, AttributeError) as e:
                raise ValueError(
                    "Input dict for inMemoryDataset does not appear to be a dict of numpy arrays."
                ) from e
        elif hasattr(inMemoryDataset, "to_parquet"):
            # This may be a pandas DataFrame; pandas is imported lazily and
            # tolerated as missing.
            try:
                import pandas as pd
            except ImportError:
                pd = None

            if pd is not None and isinstance(inMemoryDataset, pd.DataFrame):
                arrow_table = pandas_to_arrow(inMemoryDataset)

        if arrow_table is None:
            raise ValueError(
                f"Unsupported type {get_full_type_name(inMemoryDataset)} of "
                "inMemoryDataset for ParquetFormatter."
            )

        row_group_size = compute_row_group_size(arrow_table.schema)

        location = self.makeUpdatedLocation(self.fileDescriptor.location)

        pq.write_table(arrow_table, location.path, row_group_size=row_group_size)

175 

176 

def arrow_to_pandas(arrow_table: pa.Table) -> pd.DataFrame:
    """Convert a pyarrow table to a pandas DataFrame.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Any ``pandas`` metadata stored in its schema is
        used when building the ``DataFrame``.

    Returns
    -------
    dataframe : `pandas.DataFrame`
        The converted dataframe.
    """
    # Convert single-threaded; integer columns that contain nulls are
    # converted to object columns (``integer_object_nulls``).
    conversion_options = {"use_threads": False, "integer_object_nulls": True}
    return arrow_table.to_pandas(**conversion_options)

193 

194 

def arrow_to_astropy(arrow_table: pa.Table) -> atable.Table:
    """Convert a pyarrow table to an `astropy.Table`.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Table to convert. Astropy column/table metadata serialized in the
        schema metadata is restored on the result.

    Returns
    -------
    table : `astropy.Table`
        The converted astropy table.
    """
    from astropy.table import Table

    schema_metadata = arrow_table.schema.metadata
    if schema_metadata is None:
        schema_metadata = {}

    result = Table(arrow_to_numpy_dict(arrow_table))
    _apply_astropy_metadata(result, schema_metadata)
    return result

219 

220 

def arrow_to_numpy(arrow_table: pa.Table) -> np.ndarray:
    """Convert a pyarrow table to a structured numpy (record) array.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Record array with one row per table row and one field per table
        column, named like the arrow columns.
    """
    import numpy as np

    columns = arrow_to_numpy_dict(arrow_table)

    def _field_spec(field_name, values):
        # 0-/1-d columns become plain fields; higher-dimensional columns
        # become sub-array fields whose shape excludes the row axis.
        if values.ndim <= 1:
            return (field_name, values.dtype)
        return (field_name, (values.dtype, values.shape[1:]))

    record_dtype = [_field_spec(field_name, values) for field_name, values in columns.items()]

    return np.rec.fromarrays(columns.values(), dtype=record_dtype)

249 

250 

def arrow_to_numpy_dict(arrow_table: pa.Table) -> dict[str, np.ndarray]:
    """Convert a pyarrow table to a dict of numpy arrays.

    Columns containing nulls are returned as `numpy.ma.MaskedArray`.
    String/binary columns are cast to fixed-width numpy string dtypes, and
    fixed-size-list columns are stacked into multi-dimensional arrays.

    Parameters
    ----------
    arrow_table : `pyarrow.Table`
        Input arrow table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    import numpy as np

    schema = arrow_table.schema
    metadata = schema.metadata if schema.metadata is not None else {}
    string_types = (pa.string(), pa.binary())

    numpy_dict = {}
    for name in schema.names:
        arrow_type = schema.field(name).type
        arrow_column = arrow_table[name]

        if arrow_column.null_count:
            # Arrow needs nulls replaced by a value of the right type before
            # conversion; the null mask is then reapplied as a numpy mask.
            fill_value = "" if arrow_type in string_types else arrow_type.to_pandas_dtype()(0)
            values = np.ma.masked_array(
                data=arrow_column.fill_null(fill_value).to_numpy(),
                mask=arrow_column.is_null().to_numpy(),
            )
        else:
            values = arrow_column.to_numpy()

        if arrow_type in string_types:
            values = values.astype(_arrow_string_to_numpy_dtype(schema, name, values))
        elif isinstance(arrow_type, pa.FixedSizeListType):
            if len(values) > 0:
                values = np.stack(values)
            else:
                # An empty column must be coerced to the list's value type.
                values = values.astype(arrow_type.value_type.to_pandas_dtype())

            shape = _multidim_shape_from_metadata(metadata, arrow_type.list_size, name)
            values = values.reshape((len(arrow_table), *shape))

        numpy_dict[name] = values

    return numpy_dict

307 

308 

def _numpy_dict_to_numpy(numpy_dict: dict[str, np.ndarray]) -> np.ndarray:
    """Convert a dict of numpy arrays to a structured numpy array.

    The conversion round-trips through an arrow table so that it reuses the
    arrow conversion helpers of this module.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    array : `numpy.ndarray` (N,)
        Structured array with N rows and fields named from the dict keys.
    """
    intermediate = numpy_dict_to_arrow(numpy_dict)
    return arrow_to_numpy(intermediate)

323 

324 

def _numpy_to_numpy_dict(np_array: np.ndarray) -> dict[str, np.ndarray]:
    """Convert a structured numpy array to a dict of numpy arrays.

    The conversion round-trips through an arrow table so that it reuses the
    arrow conversion helpers of this module.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array with named fields.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from field name to column array.
    """
    intermediate = numpy_to_arrow(np_array)
    return arrow_to_numpy_dict(intermediate)

339 

340 

def numpy_to_arrow(np_array: np.ndarray) -> pa.Table:
    """Convert a structured numpy array to an arrow table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array; each named field becomes a column.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table. Its schema metadata records the row count
        and, for string and multi-dimensional fields, the string length or
        shape needed to restore the numpy dtype.
    """
    dtype = np_array.dtype
    type_list = _numpy_dtype_to_arrow_types(dtype)
    n_rows = len(np_array)

    md = {b"lsst::arrow::rowcount": str(n_rows)}
    for field_name in dtype.names:
        field_dtype = dtype[field_name]
        _append_numpy_string_metadata(md, field_name, field_dtype)
        _append_numpy_multidim_metadata(md, field_name, field_dtype)

    schema = pa.schema(type_list, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, n_rows, np_array, schema)
    return pa.Table.from_arrays(arrays, schema=schema)

375 

376 

def numpy_dict_to_arrow(numpy_dict: dict[str, np.ndarray]) -> pa.Table:
    """Convert a dict of numpy arrays to an arrow table.

    Parameters
    ----------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table, with row count and per-column string length
        / multi-dimensional shape recorded in the schema metadata.

    Raises
    ------
    ValueError
        Raised if columns in ``numpy_dict`` have unequal numbers of rows.
    """
    dtype, rowcount = _numpy_dict_to_dtype(numpy_dict)
    type_list = _numpy_dtype_to_arrow_types(dtype)

    md = {b"lsst::arrow::rowcount": str(rowcount)}
    # ``dtype.names`` may be None (presumably for an empty dict), in which
    # case there is no per-column metadata to add.
    for field_name in dtype.names or ():
        field_dtype = dtype[field_name]
        _append_numpy_string_metadata(md, field_name, field_dtype)
        _append_numpy_multidim_metadata(md, field_name, field_dtype)

    schema = pa.schema(type_list, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, rowcount, numpy_dict, schema)
    return pa.Table.from_arrays(arrays, schema=schema)

417 

418 

def astropy_to_arrow(astropy_table: atable.Table) -> pa.Table:
    """Convert an astropy table to an arrow table.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table. Its schema metadata holds the row count,
        per-column string lengths / shapes, and the astropy YAML header
        under ``table_meta_yaml``.
    """
    from astropy.table import meta

    dtype = astropy_table.dtype
    type_list = _numpy_dtype_to_arrow_types(dtype)
    n_rows = len(astropy_table)

    md = {b"lsst::arrow::rowcount": str(n_rows)}
    for field_name in dtype.names:
        field_dtype = dtype[field_name]
        _append_numpy_string_metadata(md, field_name, field_dtype)
        _append_numpy_multidim_metadata(md, field_name, field_dtype)

    # Serialize the astropy table/column metadata as a YAML header so that
    # it can be reapplied when converting back to astropy.
    md[b"table_meta_yaml"] = "\n".join(meta.get_yaml_from_table(astropy_table))

    schema = pa.schema(type_list, metadata=md)
    arrays = _numpy_style_arrays_to_arrow_arrays(dtype, n_rows, astropy_table, schema)
    return pa.Table.from_arrays(arrays, schema=schema)

459 

460 

def _astropy_to_numpy_dict(astropy_table: atable.Table) -> dict[str, np.ndarray]:
    """Convert an astropy table to a dict of numpy arrays.

    The conversion goes through an intermediate arrow table, so string
    lengths and multi-dimensional shapes are handled as in
    `astropy_to_arrow` and `arrow_to_numpy_dict`.

    Parameters
    ----------
    astropy_table : `astropy.Table`
        Input astropy table.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Dict with keys as the column names, values as the arrays.
    """
    return arrow_to_numpy_dict(astropy_to_arrow(astropy_table))

475 

476 

def pandas_to_arrow(dataframe: pd.DataFrame, default_length: int = 10) -> pa.Table:
    """Convert a pandas dataframe to an arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.
    default_length : `int`, optional
        String length recorded for a string column whose length cannot be
        inferred because the column is empty or contains only nulls.

    Returns
    -------
    arrow_table : `pyarrow.Table`
        Converted arrow table, with the row count and the maximum length of
        each string column recorded in the schema metadata.
    """
    arrow_table = pa.Table.from_pandas(dataframe)

    # Update the metadata (working on a copy of the schema metadata).
    md = dict(arrow_table.schema.metadata or {})

    md[b"lsst::arrow::rowcount"] = str(arrow_table.num_rows)

    # We loop through the arrow table columns because the datatypes have
    # been checked and converted from pandas objects.
    for name in arrow_table.column_names:
        if name.startswith("__"):
            continue
        column = arrow_table[name]
        if column.type == pa.string():
            # Longest non-null string; ``default`` covers both empty and
            # all-null columns, where a bare ``max`` would raise ValueError.
            strlen = max(
                (len(value) for value in column.to_pylist() if value is not None),
                default=default_length,
            )
            md[f"lsst::arrow::len::{name}".encode()] = str(strlen)

    arrow_table = arrow_table.replace_schema_metadata(md)

    return arrow_table

514 

515 

def pandas_to_astropy(dataframe: pd.DataFrame) -> atable.Table:
    """Convert a pandas dataframe to an astropy table, preserving indexes.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe; its columns must not be a
        `pandas.MultiIndex`.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table, with the dataframe index included.

    Raises
    ------
    ValueError
        Raised if the dataframe has multi-index columns.
    """
    import pandas as pd
    from astropy.table import Table

    # Hierarchical (MultiIndex) columns are not supported by this conversion.
    if not isinstance(dataframe.columns, pd.MultiIndex):
        return Table.from_pandas(dataframe, index=True)

    raise ValueError("Cannot convert a multi-index dataframe to an astropy table.")

536 

537 

def _pandas_to_numpy_dict(dataframe: pd.DataFrame) -> dict[str, np.ndarray]:
    """Convert a pandas dataframe to a dict of numpy arrays.

    The conversion goes through an intermediate arrow table.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Input pandas dataframe.

    Returns
    -------
    numpy_dict : `dict` [`str`, `numpy.ndarray`]
        Mapping from column name to column array.
    """
    intermediate = pandas_to_arrow(dataframe)
    return arrow_to_numpy_dict(intermediate)

552 

553 

def numpy_to_astropy(np_array: np.ndarray) -> atable.Table:
    """Convert a structured numpy array to an astropy table.

    Parameters
    ----------
    np_array : `numpy.ndarray`
        Structured numpy array with named fields.

    Returns
    -------
    astropy_table : `astropy.table.Table`
        Converted astropy table.
    """
    from astropy.table import Table

    # ``copy=False`` asks astropy to reuse the input data where it can.
    astropy_table = Table(data=np_array, copy=False)
    return astropy_table

570 

571 

def arrow_schema_to_pandas_index(schema: pa.Schema) -> pd.Index | pd.MultiIndex:
    """Convert an arrow schema to a pandas index/multiindex.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema. Its metadata may be None (e.g. for parquet
        files without schema metadata), which is treated as having no
        pandas metadata.

    Returns
    -------
    index : `pandas.Index` or `pandas.MultiIndex`
        Converted pandas index. A `pandas.MultiIndex` is returned only when
        the pandas metadata describes more than one column index level.
    """
    import pandas as pd

    # ``schema.metadata`` is None when the schema carries no metadata;
    # testing membership on None would raise TypeError.
    metadata = schema.metadata if schema.metadata is not None else {}

    indexes = []
    if b"pandas" in metadata:
        md = json.loads(metadata[b"pandas"])
        indexes = md["column_indexes"]

    if len(indexes) <= 1:
        # Names starting with "__" are skipped (presumably serialized pandas
        # index columns).
        return pd.Index([name for name in schema.names if not name.startswith("__")])

    raw_columns = _split_multi_index_column_names(len(indexes), schema.names)
    return pd.MultiIndex.from_tuples(raw_columns, names=[f["name"] for f in indexes])

599 

600 

def arrow_schema_to_column_list(schema: pa.Schema) -> list[str]:
    """Convert an arrow schema to a list of string column names.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    column_list : `list` [`str`]
        A new list of all column names in schema order, including any
        internal ``__``-prefixed columns.
    """
    return list(schema.names)

615 

616 

class DataFrameSchema:
    """Wrapper class for a schema for a pandas DataFrame.

    The schema is stored as a zero-row selection of the input dataframe,
    so it keeps the columns and column dtypes but no data.

    Parameters
    ----------
    dataframe : `pandas.DataFrame`
        Dataframe to turn into a schema.
    """

    def __init__(self, dataframe: pd.DataFrame) -> None:
        # An all-False boolean row selection keeps the structure, no rows.
        self._schema = dataframe.loc[[False] * len(dataframe)]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> DataFrameSchema:
        """Convert an arrow schema into a `DataFrameSchema`.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            The pyarrow schema to convert.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        # A zero-row table with this schema (metadata included) converts
        # to an empty dataframe with the corresponding pandas dtypes.
        empty_table = pa.Table.from_pylist([], schema=schema)

        return cls(empty_table.to_pandas())

    def to_arrow_schema(self) -> pa.Schema:
        """Convert to an arrow schema.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        arrow_table = pa.Table.from_pandas(self._schema)

        return arrow_table.schema

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Convert to an `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Convert to an ArrowAstropySchema.

        Returns
        -------
        arrow_astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> pd.DataFrame:
        """The zero-row dataframe that represents this schema."""
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DataFrameSchema):
            return NotImplemented

        return self._schema.equals(other._schema)

691 

692 

class ArrowAstropySchema:
    """Wrapper class for a schema for an astropy table.

    The schema is held as a zero-length slice of the input table.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Input astropy table.
    """

    def __init__(self, astropy_table: atable.Table) -> None:
        self._schema = astropy_table[:0]

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowAstropySchema:
        """Build an `ArrowAstropySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Input pyarrow schema.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        import numpy as np
        from astropy.table import Table

        # Start from an empty table with the right column dtypes, then
        # reapply the astropy metadata serialized in the arrow schema.
        empty_rows = np.zeros(0, dtype=_schema_to_dtype_list(schema))
        astropy_table = Table(data=empty_rows)

        schema_metadata = schema.metadata
        _apply_astropy_metadata(astropy_table, {} if schema_metadata is None else schema_metadata)

        return cls(astropy_table)

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        return astropy_to_arrow(self._schema).schema

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        return DataFrameSchema.from_arrow(self.to_arrow_schema())

    def to_arrow_numpy_schema(self) -> ArrowNumpySchema:
        """Return the equivalent `ArrowNumpySchema`.

        Returns
        -------
        arrow_numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        return ArrowNumpySchema.from_arrow(self.to_arrow_schema())

    @property
    def schema(self) -> atable.Table:
        """The zero-length astropy table representing this schema."""
        return self._schema

    def __repr__(self) -> str:
        return repr(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowAstropySchema):
            return NotImplemented

        # Matching dtypes imply matching column names.
        if self._schema.dtype != other._schema.dtype:
            return False

        # Then compare the per-column unit, description and format.
        compared_attributes = ("unit", "description", "format")
        return all(
            getattr(self._schema[column_name], attribute) == getattr(other._schema[column_name], attribute)
            for column_name in self._schema.columns
            for attribute in compared_attributes
        )

788 

789 

class ArrowNumpySchema:
    """Wrapper class for a schema for a numpy ndarray.

    Parameters
    ----------
    numpy_dtype : `numpy.dtype`
        Structured numpy dtype describing the columns.
    """

    def __init__(self, numpy_dtype: np.dtype) -> None:
        self._dtype = numpy_dtype

    @classmethod
    def from_arrow(cls, schema: pa.Schema) -> ArrowNumpySchema:
        """Build an `ArrowNumpySchema` from an arrow schema.

        Parameters
        ----------
        schema : `pyarrow.Schema`
            Pyarrow schema to convert.

        Returns
        -------
        numpy_schema : `ArrowNumpySchema`
            Converted arrow numpy schema.
        """
        import numpy as np

        return cls(np.dtype(_schema_to_dtype_list(schema)))

    def to_arrow_astropy_schema(self) -> ArrowAstropySchema:
        """Return the equivalent `ArrowAstropySchema`.

        Returns
        -------
        astropy_schema : `ArrowAstropySchema`
            Converted arrow astropy schema.
        """
        return ArrowAstropySchema.from_arrow(self.to_arrow_schema())

    def to_dataframe_schema(self) -> DataFrameSchema:
        """Return the equivalent `DataFrameSchema`.

        Returns
        -------
        dataframe_schema : `DataFrameSchema`
            Converted dataframe schema.
        """
        return DataFrameSchema.from_arrow(self.to_arrow_schema())

    def to_arrow_schema(self) -> pa.Schema:
        """Return the equivalent `pyarrow.Schema`.

        The schema is obtained by converting a zero-length array with this
        dtype.

        Returns
        -------
        arrow_schema : `pyarrow.Schema`
            Converted pyarrow schema.
        """
        import numpy as np

        empty_array = np.zeros(0, dtype=self._dtype)
        return numpy_to_arrow(empty_array).schema

    @property
    def schema(self) -> np.dtype:
        """The wrapped numpy dtype."""
        return self._dtype

    def __repr__(self) -> str:
        return repr(self._dtype)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrowNumpySchema):
            return NotImplemented

        return bool(self._dtype == other._dtype)

873 

874 

875def _split_multi_index_column_names(n: int, names: Iterable[str]) -> list[Sequence[str]]: 

876 """Split a string that represents a multi-index column. 

877 

878 PyArrow maps Pandas' multi-index column names (which are tuples in Python) 

879 to flat strings on disk. This routine exists to reconstruct the original 

880 tuple. 

881 

882 Parameters 

883 ---------- 

884 n : `int` 

885 Number of levels in the `pandas.MultiIndex` that is being 

886 reconstructed. 

887 names : `~collections.abc.Iterable` [`str`] 

888 Strings to be split. 

889 

890 Returns 

891 ------- 

892 column_names : `list` [`tuple` [`str`]] 

893 A list of multi-index column name tuples. 

894 """ 

895 column_names: list[Sequence[str]] = [] 

896 

897 pattern = re.compile(r"\({}\)".format(", ".join(["'(.*)'"] * n))) 

898 for name in names: 

899 m = re.search(pattern, name) 

900 if m is not None: 

901 column_names.append(m.groups()) 

902 

903 return column_names 

904 

905 

906def _standardize_multi_index_columns( 

907 pd_index: pd.MultiIndex, 

908 columns: Any, 

909 stringify: bool = True, 

910) -> list[str | Sequence[Any]]: 

911 """Transform a dictionary/iterable index from a multi-index column list 

912 into a string directly understandable by PyArrow. 

913 

914 Parameters 

915 ---------- 

916 pd_index : `pandas.MultiIndex` 

917 Pandas multi-index. 

918 columns : `list` [`tuple`] or `dict` [`str`, `str` or `list` [`str`]] 

919 Columns to standardize. 

920 stringify : `bool`, optional 

921 Should the column names be stringified? 

922 

923 Returns 

924 ------- 

925 names : `list` [`str`] 

926 Stringified representation of a multi-index column name. 

927 """ 

928 index_level_names = tuple(pd_index.names) 

929 

930 names: list[str | Sequence[Any]] = [] 

931 

932 if isinstance(columns, list): 

933 for requested in columns: 

934 if not isinstance(requested, tuple): 

935 raise ValueError( 

936 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. " 

937 f"Instead got a {get_full_type_name(requested)}." 

938 ) 

939 if stringify: 

940 names.append(str(requested)) 

941 else: 

942 names.append(requested) 

943 else: 

944 if not isinstance(columns, collections.abc.Mapping): 

945 raise ValueError( 

946 "Columns parameter for multi-index data frame must be a dictionary or list of tuples. " 

947 f"Instead got a {get_full_type_name(columns)}." 

948 ) 

949 if not set(index_level_names).issuperset(columns.keys()): 

950 raise ValueError( 

951 f"Cannot use dict with keys {set(columns.keys())} to select columns from {index_level_names}." 

952 ) 

953 factors = [ 

954 ensure_iterable(columns.get(level, pd_index.levels[i])) 

955 for i, level in enumerate(index_level_names) 

956 ] 

957 for requested in itertools.product(*factors): 

958 for i, value in enumerate(requested): 

959 if value not in pd_index.levels[i]: 

960 raise ValueError(f"Unrecognized value {value!r} for index {index_level_names[i]!r}.") 

961 if stringify: 

962 names.append(str(requested)) 

963 else: 

964 names.append(requested) 

965 

966 return names 

967 

968 

def _apply_astropy_metadata(astropy_table: atable.Table, metadata: dict) -> None:
    """Apply any astropy metadata from the schema metadata.

    Parameters
    ----------
    astropy_table : `astropy.table.Table`
        Table to apply metadata to; modified in place.
    metadata : `dict` [`bytes`]
        Arrow schema metadata dict. Only the ``b"table_meta_yaml"`` entry
        is used; when it is absent or empty the table is left unchanged.
    """
    from astropy.table import meta

    meta_yaml = metadata.get(b"table_meta_yaml", None)
    if not meta_yaml:
        return

    meta_hdr = meta.get_header_from_yaml(meta_yaml.decode("UTF8").split("\n"))

    # Set description, format, unit, meta from the column
    # metadata that was serialized with the table. Columns that have no
    # entry in the serialized header (for example, columns added after the
    # header was written) are left unchanged instead of raising KeyError.
    header_cols = {x["name"]: x for x in meta_hdr["datatype"]}
    for col in astropy_table.columns.values():
        col_header = header_cols.get(col.name, {})
        for attr in ("description", "format", "unit", "meta"):
            if attr in col_header:
                setattr(col, attr, col_header[attr])

    if "meta" in meta_hdr:
        astropy_table.meta.update(meta_hdr["meta"])

996 

997 

def _arrow_string_to_numpy_dtype(
    schema: pa.Schema, name: str, numpy_column: np.ndarray | None = None, default_length: int = 10
) -> str:
    """Get the numpy dtype string associated with an arrow column.

    The string length is taken from the ``lsst::arrow::len::<name>`` schema
    metadata if present. Otherwise it is the longest unmasked entry of
    ``numpy_column``. If neither is available, ``default_length`` is used.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    name : `str`
        Column name.
    numpy_column : `numpy.ndarray`, optional
        Column to determine numpy string dtype. May be a
        `numpy.ma.MaskedArray`; masked entries are ignored.
    default_length : `int`, optional
        String length used when it is not in the metadata and cannot be
        inferred from the column.

    Returns
    -------
    dtype_str : `str`
        Numpy dtype string (``U<n>`` for string columns, ``|S<n>``
        otherwise).
    """
    import numpy as np

    # Special-case for string and binary columns
    md_name = f"lsst::arrow::len::{name}"
    strlen = default_length
    metadata = schema.metadata if schema.metadata is not None else {}
    if (encoded := md_name.encode("UTF-8")) in metadata:
        # String/bytes length from header.
        strlen = int(metadata[encoded])
    elif numpy_column is not None:
        # Iterating a masked array yields ``np.ma.masked`` for null entries,
        # which has no ``len``; keep only the unmasked values.
        valid_values = np.ma.compressed(numpy_column)
        if len(valid_values) > 0:
            strlen = max(len(row) for row in valid_values)

    dtype = f"U{strlen}" if schema.field(name).type == pa.string() else f"|S{strlen}"

    return dtype

1034 

1035 

1036def _append_numpy_string_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1037 """Append numpy string length keys to arrow metadata. 

1038 

1039 All column types are handled, but the metadata is only modified for 

1040 string and byte columns. 

1041 

1042 Parameters 

1043 ---------- 

1044 metadata : `dict` [`bytes`, `str`] 

1045 Metadata dictionary; modified in place. 

1046 name : `str` 

1047 Column name. 

1048 dtype : `np.dtype` 

1049 Numpy dtype. 

1050 """ 

1051 import numpy as np 

1052 

1053 if dtype.type is np.str_: 

1054 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize // 4) 

1055 elif dtype.type is np.bytes_: 

1056 metadata[f"lsst::arrow::len::{name}".encode()] = str(dtype.itemsize) 

1057 

1058 

1059def _append_numpy_multidim_metadata(metadata: dict[bytes, str], name: str, dtype: np.dtype) -> None: 

1060 """Append numpy multi-dimensional shapes to arrow metadata. 

1061 

1062 All column types are handled, but the metadata is only modified for 

1063 multi-dimensional columns. 

1064 

1065 Parameters 

1066 ---------- 

1067 metadata : `dict` [`bytes`, `str`] 

1068 Metadata dictionary; modified in place. 

1069 name : `str` 

1070 Column name. 

1071 dtype : `np.dtype` 

1072 Numpy dtype. 

1073 """ 

1074 if len(dtype.shape) > 1: 

1075 metadata[f"lsst::arrow::shape::{name}".encode()] = str(dtype.shape) 

1076 

1077 

1078def _multidim_shape_from_metadata(metadata: dict[bytes, bytes], list_size: int, name: str) -> tuple[int, ...]: 

1079 """Retrieve the shape from the metadata, if available. 

1080 

1081 Parameters 

1082 ---------- 

1083 metadata : `dict` [`bytes`, `bytes`] 

1084 Metadata dictionary. 

1085 list_size : `int` 

1086 Size of the list datatype. 

1087 name : `str` 

1088 Column name. 

1089 

1090 Returns 

1091 ------- 

1092 shape : `tuple` [`int`] 

1093 Shape associated with the column. 

1094 

1095 Raises 

1096 ------ 

1097 RuntimeError 

1098 Raised if metadata is found but has incorrect format. 

1099 """ 

1100 md_name = f"lsst::arrow::shape::{name}" 

1101 if (encoded := md_name.encode("UTF-8")) in metadata: 

1102 groups = re.search(r"\((.*)\)", metadata[encoded].decode("UTF-8")) 

1103 if groups is None: 

1104 raise RuntimeError("Illegal value found in metadata.") 

1105 shape = tuple(int(x) for x in groups[1].split(",") if x != "") 

1106 else: 

1107 shape = (list_size,) 

1108 

1109 return shape 

1110 

1111 

def _schema_to_dtype_list(schema: pa.Schema) -> list[tuple[str, tuple[Any] | str]]:
    """Convert a pyarrow schema to a numpy dtype description.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Input pyarrow schema.

    Returns
    -------
    dtype_list: `list` [`tuple`]
        One ``(name, type)`` entry per schema column, in schema order.
        Fixed-size list columns map to a ``(value_dtype, shape)`` sub-array
        spec, string/binary columns to a fixed-width numpy string dtype
        string, and every other column to its pandas/numpy dtype.
    """
    schema_md = {} if schema.metadata is None else schema.metadata
    string_like = (pa.string(), pa.binary())

    def _entry(column: str) -> tuple[str, Any]:
        # Translate a single schema field into its numpy dtype entry.
        arrow_type = schema.field(column).type
        if isinstance(arrow_type, pa.FixedSizeListType):
            dims = _multidim_shape_from_metadata(schema_md, arrow_type.list_size, column)
            return (column, (arrow_type.value_type.to_pandas_dtype(), dims))
        if arrow_type in string_like:
            return (column, _arrow_string_to_numpy_dtype(schema, column))
        return (column, arrow_type.to_pandas_dtype())

    return [_entry(column) for column in schema.names]

1139 

1140 

def _numpy_dtype_to_arrow_types(dtype: np.dtype) -> list[Any]:
    """Convert a structured numpy dtype to arrow field types.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype to convert.

    Returns
    -------
    type_list : `list` [`tuple`]
        ``(name, arrow_type)`` pairs, one per field of ``dtype``; empty when
        ``dtype`` has no named fields. Sub-array fields become fixed-size
        arrow lists whose length is the total number of sub-array elements.
    """
    from math import prod

    import numpy as np

    if dtype.names is None:
        return []

    converted: list[Any] = []
    for field_name in dtype.names:
        field_dtype = dtype[field_name]
        if not field_dtype.shape:
            converted.append((field_name, pa.from_numpy_dtype(field_dtype.type)))
            continue
        # Multi-dimensional fields are flattened into one fixed-size list.
        base_dtype = cast(tuple[np.dtype, tuple[int, ...]], field_dtype.subdtype)[0]
        element_type = pa.from_numpy_dtype(base_dtype.type)
        converted.append((field_name, pa.list_(element_type, prod(field_dtype.shape))))

    return converted

1175 

1176 

1177def _numpy_dict_to_dtype(numpy_dict: dict[str, np.ndarray]) -> tuple[np.dtype, int]: 

1178 """Extract equivalent table dtype from dict of numpy arrays. 

1179 

1180 Parameters 

1181 ---------- 

1182 numpy_dict : `dict` [`str`, `numpy.ndarray`] 

1183 Dict with keys as the column names, values as the arrays. 

1184 

1185 Returns 

1186 ------- 

1187 dtype : `numpy.dtype` 

1188 dtype of equivalent table. 

1189 rowcount : `int` 

1190 Number of rows in the table. 

1191 

1192 Raises 

1193 ------ 

1194 ValueError if columns in numpy_dict have unequal numbers of rows. 

1195 """ 

1196 import numpy as np 

1197 

1198 dtype_list = [] 

1199 rowcount = 0 

1200 for name, col in numpy_dict.items(): 

1201 if rowcount == 0: 

1202 rowcount = len(col) 

1203 if len(col) != rowcount: 

1204 raise ValueError(f"Column {name} has a different number of rows.") 

1205 if len(col.shape) == 1: 

1206 dtype_list.append((name, col.dtype)) 

1207 else: 

1208 dtype_list.append((name, (col.dtype, col.shape[1:]))) 

1209 dtype = np.dtype(dtype_list) 

1210 

1211 return (dtype, rowcount) 

1212 

1213 

def _numpy_style_arrays_to_arrow_arrays(
    dtype: np.dtype,
    rowcount: int,
    np_style_arrays: dict[str, np.ndarray] | np.ndarray | atable.Table,
    schema: pa.Schema,
) -> list[pa.Array]:
    """Convert numpy-style arrays to arrow arrays.

    Multi-dimensional columns are passed to arrow as one flattened
    sub-array per row. If arrow rejects a column with
    `pyarrow.ArrowNotImplementedError` and the column is big-endian, the
    column is converted to little-endian and the conversion retried; any
    other failure is re-raised unchanged.

    Parameters
    ----------
    dtype : `numpy.dtype`
        Numpy dtype of input table/arrays.
    rowcount : `int`
        Number of rows in input table/arrays.
    np_style_arrays : `dict` [`str`, `np.ndarray`] or `np.ndarray`
                      or `astropy.table.Table`
        Arrays to convert to arrow.
    schema : `pyarrow.Schema`
        Schema of arrow table.

    Returns
    -------
    arrow_arrays : `list` [`pyarrow.Array`]
        List of converted pyarrow arrays.
    """
    import numpy as np

    def _arrow_input(column: Any, is_multidim: bool) -> Any:
        # Multi-dimensional columns become a list of flat per-row arrays.
        if not is_multidim:
            return column
        if rowcount > 0:
            return np.split(column.ravel(), rowcount)
        return []

    def _is_big_endian(column: Any) -> bool:
        # Numpy reports native byte order as "=", so "=" is big-endian only
        # on a big-endian host.
        byteorder = column.dtype.byteorder
        return byteorder == ">" or (byteorder == "=" and not np.little_endian)

    arrow_arrays: list[pa.Array] = []
    if dtype.names is None:
        return arrow_arrays

    for name in dtype.names:
        is_multidim = len(dtype[name].shape) > 0
        column = np_style_arrays[name]
        arrow_type = schema.field(name).type

        try:
            arrow_arrays.append(pa.array(_arrow_input(column, is_multidim), type=arrow_type))
        except pa.ArrowNotImplementedError:
            # Check the column itself (not the per-row split list, which has
            # no dtype) so multi-dimensional big-endian columns are handled.
            if not _is_big_endian(column):
                # This failed for some other reason so raise the exception.
                raise
            little_endian = column.astype(column.dtype.newbyteorder("<"))
            arrow_arrays.append(pa.array(_arrow_input(little_endian, is_multidim), type=arrow_type))

    return arrow_arrays

1272 

1273 

def compute_row_group_size(schema: pa.Schema, target_size: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Compute approximate row group size for a given arrow schema.

    Given a schema, this routine will compute the number of rows in a row group
    that targets the persisted size on disk (or smaller). The exact size on
    disk depends on the compression settings and ratios; typical binary data
    tables will have around 15-20% compression with the pyarrow default
    ``snappy`` compression algorithm.

    Parameters
    ----------
    schema : `pyarrow.Schema`
        Arrow table schema.
    target_size : `int`, optional
        The target size (in bytes).

    Returns
    -------
    row_group_size : `int`
        Number of rows per row group to hit the target size.
    """
    metadata = schema.metadata if schema.metadata is not None else {}

    def _type_bit_width(t: pa.DataType, name: str) -> int:
        # Estimated width in bits of one value of type ``t`` in column ``name``.
        if t in (pa.string(), pa.binary()):
            md_name = f"lsst::arrow::len::{name}"
            if (encoded := md_name.encode("UTF-8")) in metadata:
                # String/bytes length from header.
                strlen = int(metadata[encoded])
            else:
                # We don't know the string width, so guess something.
                strlen = 10
            # Assuming UTF-8 encoding, and very few wide characters.
            return 8 * strlen
        if t == pa.null():
            return 0
        if isinstance(t, pa.FixedSizeListType):
            # Recurse so lists of strings or nested lists are sized too;
            # ``bit_width`` is only available for fixed-width value types.
            return t.list_size * _type_bit_width(t.value_type, name)
        if isinstance(t, pa.ListType):
            # This is a variable length list, just choose
            # something arbitrary.
            return 10 * _type_bit_width(t.value_type, name)
        try:
            return t.bit_width
        except ValueError:
            # pyarrow raises ValueError for types without a fixed bit width
            # (e.g. large strings, structs); guess like an unknown string.
            return 8 * 10

    bit_width = sum(_type_bit_width(schema.field(name).type, name) for name in schema.names)

    # Round up to whole bytes so the estimate does not overshoot the target,
    # and insist on at least 1 byte to avoid any divide-by-zero errors.
    byte_width = max(1, (bit_width + 7) // 8)

    return target_size // byte_width