Coverage for python/lsst/daf/butler/arrow_utils.py: 78%

217 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-11 03:16 -0700

1# This file is part of butler4. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Public API: the converter interface, followed by each Arrow extension type
# paired with its scalar class.
__all__ = (
    "ToArrow",
    "RegionArrowType", "RegionArrowScalar",
    "TimespanArrowType", "TimespanArrowScalar",
    "DateTimeArrowType", "DateTimeArrowScalar",
    "UUIDArrowType", "UUIDArrowScalar",
)

41 

42import uuid 

43from abc import ABC, abstractmethod 

44from typing import Any, ClassVar, final 

45 

46import astropy.time 

47import pyarrow as pa 

48from lsst.sphgeom import Region 

49 

50from ._timespan import Timespan 

51from .time_utils import TimeConverter 

52 

53 

class ToArrow(ABC):
    """Abstract interface for objects that turn Python values into a single
    Arrow field (column) of the appropriate type.

    Concrete converters are normally obtained from the ``for_*`` static
    factory methods rather than being constructed directly.
    """

    @staticmethod
    def for_primitive(name: str, data_type: pa.DataType, nullable: bool) -> ToArrow:
        """Make a converter for a type that Arrow already supports natively.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        data_type : `pyarrow.DataType`
            Arrow data type of the field.
        nullable : `bool`
            Whether null / `None` values are permitted in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            New converter instance.
        """
        converter = _ToArrowPrimitive(name, data_type, nullable)
        return converter

    @staticmethod
    def for_uuid(name: str, nullable: bool = True) -> ToArrow:
        """Make a converter for `uuid.UUID` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`, optional
            Whether null / `None` values are permitted in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            New converter instance.
        """
        converter = _ToArrowUUID(name, nullable)
        return converter

    @staticmethod
    def for_region(name: str, nullable: bool = True) -> ToArrow:
        """Make a converter for `lsst.sphgeom.Region` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`, optional
            Whether null / `None` values are permitted in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            New converter instance.
        """
        converter = _ToArrowRegion(name, nullable)
        return converter

    @staticmethod
    def for_timespan(name: str, nullable: bool = True) -> ToArrow:
        """Make a converter for `lsst.daf.butler.Timespan` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`, optional
            Whether null / `None` values are permitted in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            New converter instance.
        """
        converter = _ToArrowTimespan(name, nullable)
        return converter

    @staticmethod
    def for_datetime(name: str, nullable: bool = True) -> ToArrow:
        """Make a converter for `astropy.time.Time` values, which are stored
        as TAI nanoseconds since 1970-01-01.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`, optional
            Whether null / `None` values are permitted in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            New converter instance.
        """
        converter = _ToArrowDateTime(name, nullable)
        return converter

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the field."""
        raise NotImplementedError

    @property
    @abstractmethod
    def nullable(self) -> bool:
        """Whether null / `None` values are permitted in the field."""
        raise NotImplementedError

    @property
    @abstractmethod
    def data_type(self) -> pa.DataType:
        """Arrow data type of the field produced by this converter."""
        raise NotImplementedError

    @property
    def field(self) -> pa.Field:
        """Arrow field produced by this converter."""
        return pa.field(self.name, self.data_type, nullable=self.nullable)

    def dictionary_encoded(self) -> ToArrow:
        """Wrap this converter so that its output is dictionary-encoded.

        The returned converter keeps this one's name and value type, but
        compresses duplicate values by storing 32-bit integer indices into a
        dictionary.
        """
        return _ToArrowDictionary(self)

    @abstractmethod
    def append(self, value: Any, column: list[Any]) -> None:
        """Append the Arrow-ready form of one value to a `list`.

        Parameters
        ----------
        value : `object`
            Original value, destined to become one row of the Arrow column.
        column : `list`
            List being accumulated.  What gets appended is up to the
            implementation; it only has to be something `finish` accepts.
        """
        raise NotImplementedError

    @abstractmethod
    def finish(self, column: list[Any]) -> pa.Array:
        """Build an Arrow array from a list filled by `append`.

        Parameters
        ----------
        column : `list`
            Column values accumulated through calls to `append`.
        """
        raise NotImplementedError

207 

208 

class _ToArrowPrimitive(ToArrow):
    """`ToArrow` implementation for types that Arrow supports directly.

    Instances should be obtained from `ToArrow.for_primitive`.
    """

    def __init__(self, name: str, data_type: pa.DataType, nullable: bool):
        self._name, self._data_type, self._nullable = name, data_type, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return self._data_type

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        # Values are stored unchanged; `finish` hands them to `pa.array`.
        column.append(value)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        return pa.array(column, type=self._data_type)

242 

243 

class _ToArrowDictionary(ToArrow):
    """`ToArrow` adapter that dictionary-encodes the output of another
    converter.

    Instances should be obtained from `ToArrow.dictionary_encoded`.
    """

    def __init__(self, to_arrow_value: ToArrow):
        # The wrapped converter does all per-value work; this adapter only
        # changes the declared type and encodes the finished array.
        self._to_arrow_value = to_arrow_value

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._to_arrow_value.name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._to_arrow_value.nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        # The index type is fixed at int32 to match what
        # `pyarrow.Array.dictionary_encode` produces in `finish`.
        value_type = self._to_arrow_value.data_type
        return pa.dictionary(pa.int32(), value_type)

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        self._to_arrow_value.append(value, column)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        plain = self._to_arrow_value.finish(column)
        return plain.dictionary_encode()

277 

278 

class _ToArrowUUID(ToArrow):
    """`ToArrow` implementation for `uuid.UUID` values.

    Instances should be obtained from `ToArrow.for_uuid`.
    """

    # Each UUID is stored as its 16 raw bytes.
    storage_type: ClassVar[pa.DataType] = pa.binary(16)

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return UUIDArrowType()

    def append(self, value: uuid.UUID | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(value.bytes)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(UUIDArrowType(), storage)

314 

315 

class _ToArrowRegion(ToArrow):
    """`ToArrow` implementation for `lsst.sphgeom.Region` values.

    Each region is stored as the variable-length bytes returned by its
    ``encode`` method.  Instances should be obtained from
    `ToArrow.for_region`.
    """

    storage_type: ClassVar[pa.DataType] = pa.binary()

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return RegionArrowType()

    def append(self, value: Region | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(value.encode())

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(RegionArrowType(), storage)

351 

352 

class _ToArrowTimespan(ToArrow):
    """`ToArrow` implementation for `lsst.daf.butler.Timespan` fields.

    Each timespan is stored as a struct of two non-nullable `int64` fields,
    ``begin_nsec`` and ``end_nsec``, taken from the two elements of
    `Timespan.nsec`.

    Should be constructed via the `ToArrow.for_timespan` factory method.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    storage_type: ClassVar[pa.DataType] = pa.struct(
        [
            pa.field("begin_nsec", pa.int64(), nullable=False),
            pa.field("end_nsec", pa.int64(), nullable=False),
        ]
    )

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return TimespanArrowType()

    def append(self, value: Timespan | None, column: list[dict[str, int] | None]) -> None:
        # Docstring inherited.
        # Rows are plain dicts keyed by struct field name; `pa.array` turns
        # them into struct values (with the null struct for `None`) in
        # `finish`.
        if value is None:
            column.append(None)
        else:
            column.append({"begin_nsec": value.nsec[0], "end_nsec": value.nsec[1]})

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(TimespanArrowType(), storage_array)

393 

394 

class _ToArrowDateTime(ToArrow):
    """`ToArrow` implementation for `astropy.time.Time` values.

    Times are stored as `int64` nanoseconds, as computed by
    `TimeConverter.astropy_to_nsec`.  Instances should be obtained from
    `ToArrow.for_datetime`.
    """

    storage_type: ClassVar[pa.DataType] = pa.int64()

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return DateTimeArrowType()

    def append(self, value: astropy.time.Time | None, column: list[int | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(TimeConverter().astropy_to_nsec(value))

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(DateTimeArrowType(), storage)

430 

431 

@final
class UUIDArrowType(pa.ExtensionType):
    """An Arrow extension type for `uuid.UUID`, stored as 16-byte
    fixed-size binary values.
    """

    def __init__(self) -> None:
        # The storage type must match the binary(16) arrays built by
        # `_ToArrowUUID.finish`, and the extension name must be unique among
        # the registered types below.
        super().__init__(_ToArrowUUID.storage_type, "uuid.UUID")

    def __arrow_ext_serialize__(self) -> bytes:
        # This type has no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> UUIDArrowType:
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[UUIDArrowScalar]:
        return UUIDArrowScalar


@final
class UUIDArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `uuid.UUID`.

    Use the standard `as_py` method to convert to an actual `uuid.UUID`
    instance; null scalars convert to `None`.
    """

    def as_py(self) -> uuid.UUID | None:
        # Converters built with nullable=True append None, so null scalars
        # must round-trip to None rather than failing on a missing value.
        if not self.is_valid:
            return None
        return uuid.UUID(bytes=self.value.as_py())


@final
class RegionArrowType(pa.ExtensionType):
    """An Arrow extension type for `lsst.sphgeom.Region`, stored as the
    variable-length bytes produced by `Region.encode`.
    """

    def __init__(self) -> None:
        super().__init__(_ToArrowRegion.storage_type, "lsst.sphgeom.Region")

    def __arrow_ext_serialize__(self) -> bytes:
        # This type has no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> RegionArrowType:
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[RegionArrowScalar]:
        return RegionArrowScalar


@final
class RegionArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.sphgeom.Region`.

    Use the standard `as_py` method to convert to an actual region; null
    scalars convert to `None`.
    """

    def as_py(self) -> Region | None:
        if not self.is_valid:
            return None
        return Region.decode(self.value.as_py())


@final
class TimespanArrowType(pa.ExtensionType):
    """An Arrow extension type for `lsst.daf.butler.Timespan`, stored as a
    struct of ``begin_nsec`` and ``end_nsec`` `int64` fields.
    """

    def __init__(self) -> None:
        super().__init__(_ToArrowTimespan.storage_type, "lsst.daf.butler.Timespan")

    def __arrow_ext_serialize__(self) -> bytes:
        # This type has no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> TimespanArrowType:
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[TimespanArrowScalar]:
        return TimespanArrowScalar


@final
class TimespanArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.daf.butler.Timespan`.

    Use the standard `as_py` method to convert to an actual timespan; null
    scalars convert to `None`.
    """

    def as_py(self) -> Timespan | None:
        if not self.is_valid:
            return None
        storage = self.value
        # Rebuild from the raw nanosecond bounds, mirroring how
        # `_ToArrowTimespan.append` stored `Timespan.nsec`.
        return Timespan(
            None, None, _nsec=(storage["begin_nsec"].as_py(), storage["end_nsec"].as_py())
        )


@final
class DateTimeArrowType(pa.ExtensionType):
    """An Arrow extension type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.
    """

    def __init__(self) -> None:
        # Storage must be the int64 type used by `_ToArrowDateTime.finish`.
        super().__init__(_ToArrowDateTime.storage_type, "astropy.time.Time")

    def __arrow_ext_serialize__(self) -> bytes:
        # This type has no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> DateTimeArrowType:
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[DateTimeArrowScalar]:
        return DateTimeArrowScalar


@final
class DateTimeArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.

    Use the standard `as_py` method to convert to an actual
    `astropy.time.Time` instance; null scalars convert to `None`.
    """

    def as_py(self) -> astropy.time.Time | None:
        if not self.is_valid:
            return None
        return TimeConverter().nsec_to_astropy(self.value.as_py())


# Register every extension type with pyarrow.  Registration is keyed on the
# extension name, so each name used above must be unique.
pa.register_extension_type(UUIDArrowType())
pa.register_extension_type(RegionArrowType())
pa.register_extension_type(TimespanArrowType())
pa.register_extension_type(DateTimeArrowType())