Coverage for python/lsst/daf/butler/arrow_utils.py: 78%

217 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-30 09:59 +0000

1# This file is part of butler4. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ( 

31 "ToArrow", 

32 "RegionArrowType", 

33 "RegionArrowScalar", 

34 "TimespanArrowType", 

35 "TimespanArrowScalar", 

36 "DateTimeArrowType", 

37 "DateTimeArrowScalar", 

38 "UUIDArrowType", 

39 "UUIDArrowScalar", 

40) 

41 

42import uuid 

43from abc import ABC, abstractmethod 

44from typing import Any, ClassVar, final 

45 

46import astropy.time 

47import pyarrow as pa 

48from lsst.sphgeom import Region 

49 

50from ._timespan import Timespan 

51from .time_utils import TimeConverter 

52 

53 

class ToArrow(ABC):
    """Abstract converter from Python objects to a single Arrow field.

    Concrete converters are obtained from the ``for_*`` static factory
    methods. Rows are accumulated into a plain `list` with `append`, and the
    completed list is turned into an Arrow array with `finish`.
    """

    @staticmethod
    def for_primitive(name: str, data_type: pa.DataType, nullable: bool) -> ToArrow:
        """Build a converter for a type that Arrow supports natively.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        data_type : `pyarrow.DataType`
            Arrow data type object.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        converter = _ToArrowPrimitive(name, data_type, nullable)
        return converter

    @staticmethod
    def for_uuid(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `uuid.UUID` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        converter = _ToArrowUUID(name, nullable)
        return converter

    @staticmethod
    def for_region(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.sphgeom.Region` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        converter = _ToArrowRegion(name, nullable)
        return converter

    @staticmethod
    def for_timespan(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.daf.butler.Timespan` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        converter = _ToArrowTimespan(name, nullable)
        return converter

    @staticmethod
    def for_datetime(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `astropy.time.Time` values, which are stored
        as TAI nanoseconds since 1970-01-01.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        converter = _ToArrowDateTime(name, nullable)
        return converter

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the schema field (`str`)."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def nullable(self) -> bool:
        """Whether null or `None` values are allowed in the field (`bool`)."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def data_type(self) -> pa.DataType:
        """Arrow data type of the field produced by this converter
        (`pyarrow.DataType`).
        """
        raise NotImplementedError()

    @property
    def field(self) -> pa.Field:
        """Arrow field produced by this converter (`pyarrow.Field`), built
        from `name`, `data_type` and `nullable`.
        """
        field_name = self.name
        field_type = self.data_type
        return pa.field(field_name, field_type, self.nullable)

    def dictionary_encoded(self) -> ToArrow:
        """Wrap this converter so its output is dictionary-encoded.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter with the same name and value type that compresses
            duplicate values via dictionary encoding (32-bit integer
            indices).
        """
        return _ToArrowDictionary(self)

    @abstractmethod
    def append(self, value: Any, column: list[Any]) -> None:
        """Add the Arrow representation of one object to a `list`.

        Parameters
        ----------
        value : `object`
            Original value to be converted to a row in an Arrow column.
        column : `list`
            List of values to append to.  What exactly gets appended is up
            to the implementation; the only requirement is that `finish` be
            able to consume this `list` later.
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, column: list[Any]) -> pa.Array:
        """Turn a list of values built up by `append` into an Arrow array.

        Parameters
        ----------
        column : `list`
            List of column values populated by `append`.

        Returns
        -------
        array : `pyarrow.Array`
            Arrow array holding the converted column.
        """
        raise NotImplementedError()

207 

208 

class _ToArrowPrimitive(ToArrow):
    """Converter for values whose type Arrow already supports directly.

    Instances are meant to be created via `ToArrow.for_primitive`.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    data_type : `pyarrow.DataType`
        Arrow data type object.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    def __init__(self, name: str, data_type: pa.DataType, nullable: bool):
        self._name, self._data_type, self._nullable = name, data_type, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return self._data_type

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        # Values are stored untouched; the conversion happens in `finish`.
        column.append(value)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        return pa.array(column, type=self._data_type)

242 

243 

class _ToArrowDictionary(ToArrow):
    """Converter that dictionary-encodes the output of another converter.

    Instances are meant to be created via `ToArrow.dictionary_encoded`.

    Parameters
    ----------
    to_arrow_value : `ToArrow`
        Converter for the underlying (un-encoded) values; name, nullability
        and row handling are all delegated to it.
    """

    def __init__(self, to_arrow_value: ToArrow):
        self._to_arrow_value = to_arrow_value

    @property
    def name(self) -> str:
        # Docstring inherited.
        wrapped = self._to_arrow_value
        return wrapped.name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        wrapped = self._to_arrow_value
        return wrapped.nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        # The index type is hard-coded to int32 because that is what the
        # pa.Array.dictionary_encode() method used in `finish` produces.
        value_type = self._to_arrow_value.data_type
        return pa.dictionary(pa.int32(), value_type)

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        self._to_arrow_value.append(value, column)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        # Build the plain array first, then let Arrow do the encoding.
        plain_array = self._to_arrow_value.finish(column)
        return plain_array.dictionary_encode()

277 

278 

class _ToArrowUUID(ToArrow):
    """Converter for `uuid.UUID` values, stored as 16-byte binary.

    Instances are meant to be created via `ToArrow.for_uuid`.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    # Arrow type of the raw storage array wrapped by the extension type.
    storage_type: ClassVar[pa.DataType] = pa.binary(16)

    def __init__(self, name: str, nullable: bool):
        self._name, self._nullable = name, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return UUIDArrowType()

    def append(self, value: uuid.UUID | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(value.bytes)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        raw_bytes = pa.array(column, self.storage_type)
        extension_type = UUIDArrowType()
        return pa.ExtensionArray.from_storage(extension_type, raw_bytes)

314 

315 

class _ToArrowRegion(ToArrow):
    """Converter for `lsst.sphgeom.Region` values, stored in their encoded
    binary form.

    Instances are meant to be created via `ToArrow.for_region`.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    # Arrow type of the raw storage array wrapped by the extension type.
    storage_type: ClassVar[pa.DataType] = pa.binary()

    def __init__(self, name: str, nullable: bool):
        self._name, self._nullable = name, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return RegionArrowType()

    def append(self, value: Region | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(value.encode())

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        encoded_regions = pa.array(column, self.storage_type)
        extension_type = RegionArrowType()
        return pa.ExtensionArray.from_storage(extension_type, encoded_regions)

351 

352 

class _ToArrowTimespan(ToArrow):
    """Converter for `lsst.daf.butler.Timespan` values, stored as a struct
    of two 64-bit integer bounds.

    Instances are meant to be created via `ToArrow.for_timespan`.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    # Arrow type of the raw storage array wrapped by the extension type.
    storage_type: ClassVar[pa.DataType] = pa.struct(
        [
            pa.field("begin_nsec", pa.int64(), nullable=False),
            pa.field("end_nsec", pa.int64(), nullable=False),
        ]
    )

    def __init__(self, name: str, nullable: bool):
        self._name, self._nullable = name, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return TimespanArrowType()

    def append(self, value: Timespan | None, column: list[pa.StructScalar | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
            return
        # Rows are plain dicts keyed by the struct field names.
        row = {"begin_nsec": value._nsec[0], "end_nsec": value._nsec[1]}
        column.append(row)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        bounds = pa.array(column, self.storage_type)
        extension_type = TimespanArrowType()
        return pa.ExtensionArray.from_storage(extension_type, bounds)

395 

396 

class _ToArrowDateTime(ToArrow):
    """Converter for `astropy.time.Time` values, stored as 64-bit integers.

    Instances are meant to be created via `ToArrow.for_datetime`.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    # Arrow type of the raw storage array wrapped by the extension type.
    storage_type: ClassVar[pa.DataType] = pa.int64()

    def __init__(self, name: str, nullable: bool):
        self._name, self._nullable = name, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return DateTimeArrowType()

    def append(self, value: astropy.time.Time | None, column: list[int | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(TimeConverter().astropy_to_nsec(value))

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        nanoseconds = pa.array(column, self.storage_type)
        extension_type = DateTimeArrowType()
        return pa.ExtensionArray.from_storage(extension_type, nanoseconds)

432 

433 

@final
class UUIDArrowType(pa.ExtensionType):
    """An Arrow extension type for `uuid.UUID`, stored as 16-byte binary."""

    def __init__(self) -> None:
        # The storage type must be the one `_ToArrowUUID.finish` uses to
        # build the storage array, and the extension name must be unique to
        # this type (it previously reused the storage and name belonging to
        # the timespan / datetime types).
        super().__init__(_ToArrowUUID.storage_type, "uuid.UUID")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type has no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> UUIDArrowType:
        # No parameters to restore; both arguments are ignored.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[UUIDArrowScalar]:
        # Scalars of this type convert to `uuid.UUID` via `as_py`.
        return UUIDArrowScalar

452 

453 

@final
class UUIDArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `uuid.UUID`.

    Use the standard `as_py` method to convert to an actual `uuid.UUID`
    instance.
    """

    def as_py(self) -> uuid.UUID | None:
        """Convert this scalar to a Python object.

        Returns
        -------
        value : `uuid.UUID` or `None`
            The UUID rebuilt from the stored bytes, or `None` if this
            scalar is null.
        """
        storage = self.value
        # pyarrow exposes a null extension scalar as ``value is None``;
        # propagate that instead of failing on ``None.as_py()``.
        if storage is None:
            return None
        return uuid.UUID(bytes=storage.as_py())

464 

465 

@final
class RegionArrowType(pa.ExtensionType):
    """An Arrow extension type for `lsst.sphgeom.Region`."""

    def __init__(self) -> None:
        # Storage type is shared with the converter that builds the arrays.
        region_storage = _ToArrowRegion.storage_type
        super().__init__(region_storage, "lsst.sphgeom.Region")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type has no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> RegionArrowType:
        # No parameters to restore; both arguments are ignored.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[RegionArrowScalar]:
        # Scalars of this type convert to regions via `as_py`.
        return RegionArrowScalar

482 

483 

@final
class RegionArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.sphgeom.Region`.

    Use the standard `as_py` method to convert to an actual region.
    """

    def as_py(self) -> Region | None:
        """Convert this scalar to a Python object.

        Returns
        -------
        region : `lsst.sphgeom.Region` or `None`
            The region decoded from the stored bytes, or `None` if this
            scalar is null.
        """
        storage = self.value
        # pyarrow exposes a null extension scalar as ``value is None``;
        # propagate that instead of failing on ``None.as_py()``.
        if storage is None:
            return None
        return Region.decode(storage.as_py())

493 

494 

@final
class TimespanArrowType(pa.ExtensionType):
    """An Arrow extension type for `lsst.daf.butler.Timespan`."""

    def __init__(self) -> None:
        # Storage type is shared with the converter that builds the arrays.
        timespan_storage = _ToArrowTimespan.storage_type
        super().__init__(timespan_storage, "lsst.daf.butler.Timespan")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type has no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> TimespanArrowType:
        # No parameters to restore; both arguments are ignored.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[TimespanArrowScalar]:
        # Scalars of this type convert to timespans via `as_py`.
        return TimespanArrowScalar

511 

512 

@final
class TimespanArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.daf.butler.Timespan`.

    Use the standard `as_py` method to convert to an actual timespan.
    """

    def as_py(self) -> Timespan | None:
        """Convert this scalar to a Python object.

        Returns
        -------
        timespan : `lsst.daf.butler.Timespan` or `None`
            The timespan rebuilt from the stored ``begin_nsec`` and
            ``end_nsec`` struct fields, or `None` if this scalar is null.
        """
        storage = self.value
        # pyarrow exposes a null extension scalar as ``value is None``;
        # propagate that instead of failing on the struct lookup.
        if storage is None:
            return None
        begin_nsec = storage["begin_nsec"].as_py()
        # The upper bound lives in its own field; reading "begin_nsec" twice
        # would yield a timespan whose end equals its beginning.
        end_nsec = storage["end_nsec"].as_py()
        return Timespan(None, None, _nsec=(begin_nsec, end_nsec))

524 

525 

@final
class DateTimeArrowType(pa.ExtensionType):
    """An Arrow extension type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.
    """

    def __init__(self) -> None:
        # The storage type must be the int64 type `_ToArrowDateTime.finish`
        # uses to build the storage array, not the timespan struct.
        super().__init__(_ToArrowDateTime.storage_type, "astropy.time.Time")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type has no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> DateTimeArrowType:
        # No parameters to restore; both arguments are ignored.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[DateTimeArrowScalar]:
        # Scalars of this type convert to `astropy.time.Time` via `as_py`.
        return DateTimeArrowScalar

544 

545 

@final
class DateTimeArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.

    Use the standard `as_py` method to convert to an actual `astropy.time.Time`
    instance.
    """

    def as_py(self) -> astropy.time.Time | None:
        """Convert this scalar to a Python object.

        Returns
        -------
        time : `astropy.time.Time` or `None`
            The time rebuilt from the stored integer nanoseconds, or `None`
            if this scalar is null.
        """
        storage = self.value
        # pyarrow exposes a null extension scalar as ``value is None``;
        # propagate that instead of failing on ``None.as_py()``.
        if storage is None:
            return None
        return TimeConverter().nsec_to_astropy(storage.as_py())

557 

558 

# Register the extension types with pyarrow when this module is imported.
# NOTE(review): UUIDArrowType is not registered here - confirm whether that
# omission is intentional.
for _extension_cls in (RegionArrowType, TimespanArrowType, DateTimeArrowType):
    pa.register_extension_type(_extension_cls())
del _extension_cls