Coverage for python/lsst/daf/butler/arrow_utils.py: 78%
217 statements
coverage.py v7.4.4, created at 2024-04-05 02:53 -0700
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = (
    "ToArrow",
    "RegionArrowType",
    "RegionArrowScalar",
    "TimespanArrowType",
    "TimespanArrowScalar",
    "DateTimeArrowType",
    "DateTimeArrowScalar",
    "UUIDArrowType",
    "UUIDArrowScalar",
)

import uuid
from abc import ABC, abstractmethod
from typing import Any, ClassVar, final

import astropy.time
import pyarrow as pa
from lsst.sphgeom import Region

from ._timespan import Timespan
from .time_utils import TimeConverter


class ToArrow(ABC):
    """An abstract interface for converting objects to an Arrow field of the
    appropriate type.
    """

    @staticmethod
    def for_primitive(name: str, data_type: pa.DataType, nullable: bool) -> ToArrow:
        """Return a converter for a primitive type already supported by Arrow.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        data_type : `pyarrow.DataType`
            Arrow data type object.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowPrimitive(name, data_type, nullable)
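
    # A minimal usage sketch for this factory (illustrative only; the field
    # name "visit_id" and the values are hypothetical, not part of this
    # module):
    #
    #     converter = ToArrow.for_primitive("visit_id", pa.int64(), nullable=False)
    #     column: list[int] = []
    #     for value in (1, 2, 3):
    #         converter.append(value, column)
    #     array = converter.finish(column)  # -> pyarrow int64 Array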

    @staticmethod
    def for_uuid(name: str, nullable: bool = True) -> ToArrow:
        """Return a converter for `uuid.UUID`.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowUUID(name, nullable)

    @staticmethod
    def for_region(name: str, nullable: bool = True) -> ToArrow:
        """Return a converter for `lsst.sphgeom.Region`.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowRegion(name, nullable)

    @staticmethod
    def for_timespan(name: str, nullable: bool = True) -> ToArrow:
        """Return a converter for `lsst.daf.butler.Timespan`.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowTimespan(name, nullable)

    @staticmethod
    def for_datetime(name: str, nullable: bool = True) -> ToArrow:
        """Return a converter for `astropy.time.Time`, stored as TAI
        nanoseconds since 1970-01-01.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowDateTime(name, nullable)
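
    # Sketch of the datetime conversion (the field name and timestamp are
    # hypothetical): each `astropy.time.Time` is reduced to an integer count
    # of TAI nanoseconds since 1970-01-01 via `TimeConverter`.
    #
    #     converter = ToArrow.for_datetime("timestamp")
    #     column: list[int | None] = []
    #     converter.append(astropy.time.Time("2024-01-01T00:00:00", scale="tai"), column)
    #     converter.append(None, column)  # allowed because nullable defaults to True
    #     array = converter.finish(column)  # DateTimeArrowType extension array
    #     array[0].as_py()  # back to astropy.time.Time via DateTimeArrowScalar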

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the field."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def nullable(self) -> bool:
        """Whether the field permits null or `None` values."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def data_type(self) -> pa.DataType:
        """Arrow data type for the field this converter produces."""
        raise NotImplementedError()

    @property
    def field(self) -> pa.Field:
        """Arrow field this converter produces."""
        return pa.field(self.name, self.data_type, self.nullable)

    def dictionary_encoded(self) -> ToArrow:
        """Return a new converter with the same name and type, but using
        dictionary encoding (to 32-bit integers) to compress duplicate values.
        """
        return _ToArrowDictionary(self)
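
    # Sketch of dictionary encoding (the field name "band" and values are
    # hypothetical): repeated values are stored once in a dictionary and the
    # column holds 32-bit indices into it.
    #
    #     converter = ToArrow.for_primitive("band", pa.string(), nullable=False)
    #     converter = converter.dictionary_encoded()
    #     column: list[str] = []
    #     for value in ("g", "r", "g", "g", "r"):
    #         converter.append(value, column)
    #     array = converter.finish(column)
    #     assert array.type == pa.dictionary(pa.int32(), pa.string())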

    @abstractmethod
    def append(self, value: Any, column: list[Any]) -> None:
        """Append an object's arrow representation to a `list`.

        Parameters
        ----------
        value : `object`
            Original value to be converted to a row in an Arrow column.
        column : `list`
            List of values to append to. The type of value to append is
            implementation-defined; the only requirement is that `finish` be
            able to handle this `list` later.
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, column: list[Any]) -> pa.Array:
        """Convert a list of values constructed via `append` into an Arrow
        array.

        Parameters
        ----------
        column : `list`
            List of column values populated by `append`.
        """
        raise NotImplementedError()
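
    # Typical usage drives several converters side by side and assembles the
    # results into a table. A minimal sketch, assuming a hypothetical iterable
    # of (dataset_id, ingest_date) tuples named ``rows``:
    #
    #     converters = [ToArrow.for_uuid("id", nullable=False), ToArrow.for_datetime("ingest_date")]
    #     columns: list[list] = [[] for _ in converters]
    #     for row in rows:
    #         for converter, column, value in zip(converters, columns, row):
    #             converter.append(value, column)
    #     table = pa.Table.from_arrays(
    #         [c.finish(col) for c, col in zip(converters, columns)],
    #         schema=pa.schema([c.field for c in converters]),
    #     )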


class _ToArrowPrimitive(ToArrow):
    """`ToArrow` implementation for primitive types supported directly by
    Arrow.

    Should be constructed via the `ToArrow.for_primitive` factory method.
    """

    def __init__(self, name: str, data_type: pa.DataType, nullable: bool):
        self._name = name
        self._data_type = data_type
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return self._data_type

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        column.append(value)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        return pa.array(column, self._data_type)


class _ToArrowDictionary(ToArrow):
    """`ToArrow` implementation for Arrow dictionary fields.

    Should be constructed via the `ToArrow.dictionary_encoded` factory method.
    """

    def __init__(self, to_arrow_value: ToArrow):
        self._to_arrow_value = to_arrow_value

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._to_arrow_value.name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._to_arrow_value.nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        # We hard-code int32 as the index type here because that's what
        # the pa.Array.dictionary_encode() method does.
        return pa.dictionary(pa.int32(), self._to_arrow_value.data_type)

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        self._to_arrow_value.append(value, column)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        return self._to_arrow_value.finish(column).dictionary_encode()


class _ToArrowUUID(ToArrow):
    """`ToArrow` implementation for `uuid.UUID` fields.

    Should be constructed via the `ToArrow.for_uuid` factory method.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    storage_type: ClassVar[pa.DataType] = pa.binary(16)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return UUIDArrowType()

    def append(self, value: uuid.UUID | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        column.append(value.bytes if value is not None else None)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(UUIDArrowType(), storage_array)


class _ToArrowRegion(ToArrow):
    """`ToArrow` implementation for `lsst.sphgeom.Region` fields.

    Should be constructed via the `ToArrow.for_region` factory method.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    storage_type: ClassVar[pa.DataType] = pa.binary()

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return RegionArrowType()

    def append(self, value: Region | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        column.append(value.encode() if value is not None else None)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(RegionArrowType(), storage_array)


class _ToArrowTimespan(ToArrow):
    """`ToArrow` implementation for `lsst.daf.butler.Timespan` fields.

    Should be constructed via the `ToArrow.for_timespan` factory method.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    storage_type: ClassVar[pa.DataType] = pa.struct(
        [
            pa.field("begin_nsec", pa.int64(), nullable=False),
            pa.field("end_nsec", pa.int64(), nullable=False),
        ]
    )

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return TimespanArrowType()

    def append(self, value: Timespan | None, column: list[pa.StructScalar | None]) -> None:
        # Docstring inherited.
        column.append(
            {"begin_nsec": value._nsec[0], "end_nsec": value._nsec[1]} if value is not None else None
        )

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(TimespanArrowType(), storage_array)


class _ToArrowDateTime(ToArrow):
    """`ToArrow` implementation for `astropy.time.Time` fields.

    Should be constructed via the `ToArrow.for_datetime` factory method.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    storage_type: ClassVar[pa.DataType] = pa.int64()

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return DateTimeArrowType()

    def append(self, value: astropy.time.Time | None, column: list[int | None]) -> None:
        # Docstring inherited.
        column.append(TimeConverter().astropy_to_nsec(value) if value is not None else None)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(DateTimeArrowType(), storage_array)


@final
class UUIDArrowType(pa.ExtensionType):
    """An Arrow extension type for `uuid.UUID`."""

    def __init__(self) -> None:
        super().__init__(_ToArrowUUID.storage_type, "uuid.UUID")

    def __arrow_ext_serialize__(self) -> bytes:
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> UUIDArrowType:
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[UUIDArrowScalar]:
        return UUIDArrowScalar


@final
class UUIDArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `uuid.UUID`.

    Use the standard `as_py` method to convert to an actual `uuid.UUID`
    instance.
    """

    def as_py(self) -> uuid.UUID:
        return uuid.UUID(bytes=self.value.as_py())
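

# Illustrative round-trip sketch for the UUID extension type (the value is
# hypothetical):
#
#     original = uuid.uuid4()
#     storage = pa.array([original.bytes], pa.binary(16))
#     array = pa.ExtensionArray.from_storage(UUIDArrowType(), storage)
#     assert array[0].as_py() == original  # array[0] is a UUIDArrowScalar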


@final
class RegionArrowType(pa.ExtensionType):
    """An Arrow extension type for `lsst.sphgeom.Region`."""

    def __init__(self) -> None:
        super().__init__(_ToArrowRegion.storage_type, "lsst.sphgeom.Region")

    def __arrow_ext_serialize__(self) -> bytes:
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> RegionArrowType:
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[RegionArrowScalar]:
        return RegionArrowScalar


@final
class RegionArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.sphgeom.Region`.

    Use the standard `as_py` method to convert to an actual region.
    """

    def as_py(self) -> Region:
        return Region.decode(self.value.as_py())


@final
class TimespanArrowType(pa.ExtensionType):
    """An Arrow extension type for `lsst.daf.butler.Timespan`."""

    def __init__(self) -> None:
        super().__init__(_ToArrowTimespan.storage_type, "lsst.daf.butler.Timespan")

    def __arrow_ext_serialize__(self) -> bytes:
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> TimespanArrowType:
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[TimespanArrowScalar]:
        return TimespanArrowScalar


@final
class TimespanArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.daf.butler.Timespan`.

    Use the standard `as_py` method to convert to an actual timespan.
    """

    def as_py(self) -> Timespan:
        return Timespan(
            None, None, _nsec=(self.value["begin_nsec"].as_py(), self.value["end_nsec"].as_py())
        )
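

# Sketch of how the struct storage maps back to a `Timespan` (the nanosecond
# values are hypothetical):
#
#     storage = pa.array([{"begin_nsec": 0, "end_nsec": 10}], _ToArrowTimespan.storage_type)
#     array = pa.ExtensionArray.from_storage(TimespanArrowType(), storage)
#     array[0].as_py()  # -> Timespan rebuilt from the two int64 fields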


@final
class DateTimeArrowType(pa.ExtensionType):
    """An Arrow extension type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.
    """

    def __init__(self) -> None:
        super().__init__(_ToArrowDateTime.storage_type, "astropy.time.Time")

    def __arrow_ext_serialize__(self) -> bytes:
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> DateTimeArrowType:
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[DateTimeArrowScalar]:
        return DateTimeArrowScalar


@final
class DateTimeArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.

    Use the standard `as_py` method to convert to an actual `astropy.time.Time`
    instance.
    """

    def as_py(self) -> astropy.time.Time:
        return TimeConverter().nsec_to_astropy(self.value.as_py())


pa.register_extension_type(RegionArrowType())
pa.register_extension_type(TimespanArrowType())
pa.register_extension_type(DateTimeArrowType())
pa.register_extension_type(UUIDArrowType())
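

# The registrations above let Arrow reattach these extension types when data
# is read back, e.g. from an IPC stream. A minimal sketch (the contents are
# hypothetical):
#
#     array = pa.ExtensionArray.from_storage(
#         UUIDArrowType(), pa.array([uuid.uuid4().bytes], pa.binary(16))
#     )
#     table = pa.table({"id": array})
#     sink = pa.BufferOutputStream()
#     with pa.ipc.new_stream(sink, table.schema) as writer:
#         writer.write_table(table)
#     restored = pa.ipc.open_stream(sink.getvalue()).read_all()
#     assert restored.column("id").type == UUIDArrowType()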