Coverage for python/lsst/daf/butler/arrow_utils.py: 78%
217 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-13 10:57 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-13 10:57 +0000
1# This file is part of butler4.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
# Public names exported by this module: the `ToArrow` converter interface and
# the Arrow extension type / scalar classes defined further down in this file.
__all__ = (
    "ToArrow",
    "RegionArrowType",
    "RegionArrowScalar",
    "TimespanArrowType",
    "TimespanArrowScalar",
    "DateTimeArrowType",
    "DateTimeArrowScalar",
    "UUIDArrowType",
    "UUIDArrowScalar",
)
42import uuid
43from abc import ABC, abstractmethod
44from typing import Any, ClassVar, final
46import astropy.time
47import pyarrow as pa
48from lsst.sphgeom import Region
50from ._timespan import Timespan
51from .time_utils import TimeConverter
class ToArrow(ABC):
    """Abstract converter from Python objects to a single Arrow field.

    Concrete converters are obtained from the ``for_*`` static factory
    methods or from `dictionary_encoded`.  Values are accumulated in a plain
    `list` with `append` and turned into an Arrow array with `finish`.
    """

    @staticmethod
    def for_primitive(name: str, data_type: pa.DataType, nullable: bool) -> ToArrow:
        """Build a converter for a type that Arrow supports natively.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        data_type : `pyarrow.DataType`
            Arrow data type object.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowPrimitive(name, data_type, nullable)

    @staticmethod
    def for_uuid(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `uuid.UUID` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowUUID(name, nullable)

    @staticmethod
    def for_region(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.sphgeom.Region` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowRegion(name, nullable)

    @staticmethod
    def for_timespan(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.daf.butler.Timespan` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowTimespan(name, nullable)

    @staticmethod
    def for_datetime(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `astropy.time.Time` values, which are stored
        as TAI nanoseconds since 1970-01-01.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowDateTime(name, nullable)

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the field."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def nullable(self) -> bool:
        """Whether the field permits null or `None` values."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def data_type(self) -> pa.DataType:
        """Arrow data type for the field this converter produces."""
        raise NotImplementedError()

    @property
    def field(self) -> pa.Field:
        """Arrow field this converter produces, assembled from `name`,
        `data_type` and `nullable`.
        """
        return pa.field(self.name, self.data_type, nullable=self.nullable)

    def dictionary_encoded(self) -> ToArrow:
        """Wrap this converter in one that dictionary-encodes its output.

        Returns
        -------
        to_arrow : `ToArrow`
            New converter with the same name and value type, but using
            dictionary encoding (to 32-bit integers) to compress duplicate
            values.
        """
        return _ToArrowDictionary(self)

    @abstractmethod
    def append(self, value: Any, column: list[Any]) -> None:
        """Append the Arrow-ready representation of an object to a `list`.

        Parameters
        ----------
        value : `object`
            Original value to be converted to a row in an Arrow column.
        column : `list`
            List of values to append to.  What exactly gets appended is up to
            the implementation; the only requirement is that `finish` be able
            to handle this `list` later.
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, column: list[Any]) -> pa.Array:
        """Turn a list of values built up via `append` into an Arrow array.

        Parameters
        ----------
        column : `list`
            List of column values populated by `append`.

        Returns
        -------
        array : `pyarrow.Array`
            Arrow array holding the column values.
        """
        raise NotImplementedError()
class _ToArrowPrimitive(ToArrow):
    """`ToArrow` implementation for primitive types supported directly by
    Arrow.

    Should be constructed via the `ToArrow.for_primitive` factory method.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    data_type : `pyarrow.DataType`
        Arrow data type object.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    def __init__(self, name: str, data_type: pa.DataType, nullable: bool):
        self._name = name
        self._data_type = data_type
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return self._data_type

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        # Values are stored untouched; the conversion to Arrow is left
        # entirely to `pa.array` in `finish`.
        column.append(value)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        return pa.array(column, type=self._data_type)
class _ToArrowDictionary(ToArrow):
    """`ToArrow` implementation for Arrow dictionary fields.

    Should be constructed via the `ToArrow.dictionary_encoded` factory method.

    Parameters
    ----------
    to_arrow_value : `ToArrow`
        Converter for the underlying (un-encoded) values; name, nullability
        and value conversion are all delegated to it.
    """

    def __init__(self, to_arrow_value: ToArrow):
        self._to_arrow_value = to_arrow_value

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._to_arrow_value.name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._to_arrow_value.nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        # The index type is hard-coded to int32 because that is what
        # pa.Array.dictionary_encode() (used in `finish`) produces.
        value_type = self._to_arrow_value.data_type
        return pa.dictionary(pa.int32(), value_type)

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        self._to_arrow_value.append(value, column)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        # Build the plain array first, then let Arrow dictionary-encode it.
        plain_array = self._to_arrow_value.finish(column)
        return plain_array.dictionary_encode()
class _ToArrowUUID(ToArrow):
    """`ToArrow` implementation for `uuid.UUID` fields.

    Should be constructed via the `ToArrow.for_uuid` factory method.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    # Arrow storage type for the extension array: 16-byte fixed-size binary,
    # holding `uuid.UUID.bytes`.
    storage_type: ClassVar[pa.DataType] = pa.binary(16)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return UUIDArrowType()

    def append(self, value: uuid.UUID | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        # A nullable field may legitimately receive `None`; pass it through
        # so `pa.array` records a null instead of failing on `None.bytes`.
        column.append(None if value is None else value.bytes)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(UUIDArrowType(), storage_array)
class _ToArrowRegion(ToArrow):
    """`ToArrow` implementation for `lsst.sphgeom.Region` fields.

    Should be constructed via the `ToArrow.for_region` factory method.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    # Arrow storage type for the extension array: variable-length binary,
    # holding the result of `Region.encode`.
    storage_type: ClassVar[pa.DataType] = pa.binary()

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return RegionArrowType()

    def append(self, value: Region | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        # A nullable field may legitimately receive `None`; pass it through
        # so `pa.array` records a null instead of failing on `None.encode()`.
        column.append(None if value is None else value.encode())

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(RegionArrowType(), storage_array)
class _ToArrowTimespan(ToArrow):
    """`ToArrow` implementation for `lsst.daf.butler.Timespan` fields.

    Should be constructed via the `ToArrow.for_timespan` factory method.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    # Arrow storage type for the extension array: a struct of two non-null
    # 64-bit integers holding the begin and end bounds in nanoseconds.
    storage_type: ClassVar[pa.DataType] = pa.struct(
        [
            pa.field("begin_nsec", pa.int64(), nullable=False),
            pa.field("end_nsec", pa.int64(), nullable=False),
        ]
    )

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return TimespanArrowType()

    def append(self, value: Timespan | None, column: list[dict[str, int] | None]) -> None:
        # Docstring inherited.
        # A nullable field may legitimately receive `None`; pass it through
        # so `pa.array` records a null struct instead of failing on
        # `None._nsec`.
        if value is None:
            column.append(None)
            return
        # Each row is a plain dict keyed by the struct field names; `finish`
        # relies on `pa.array` to convert these to struct values.
        begin_nsec, end_nsec = value._nsec[0], value._nsec[1]
        column.append({"begin_nsec": begin_nsec, "end_nsec": end_nsec})

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(TimespanArrowType(), storage_array)
class _ToArrowDateTime(ToArrow):
    """`ToArrow` implementation for `astropy.time.Time` fields.

    Should be constructed via the `ToArrow.for_datetime` factory method.

    Parameters
    ----------
    name : `str`
        Name of the schema field.
    nullable : `bool`
        Whether the field should permit null or `None` values.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    # Arrow storage type for the extension array: a 64-bit integer holding
    # the value returned by `TimeConverter.astropy_to_nsec`.
    storage_type: ClassVar[pa.DataType] = pa.int64()

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return DateTimeArrowType()

    def append(self, value: astropy.time.Time | None, column: list[int | None]) -> None:
        # Docstring inherited.
        # A nullable field may legitimately receive `None`; pass it through
        # so `pa.array` records a null instead of handing `None` to the time
        # converter.
        column.append(None if value is None else TimeConverter().astropy_to_nsec(value))

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(DateTimeArrowType(), storage_array)
@final
class UUIDArrowType(pa.ExtensionType):
    """An Arrow extension type for `uuid.UUID`, stored as 16 raw bytes."""

    def __init__(self) -> None:
        # The storage type has to be the one `_ToArrowUUID.finish` builds its
        # storage array with, and the extension name has to be distinct from
        # the other extension types in this module (pyarrow registers
        # extension types by name).
        super().__init__(_ToArrowUUID.storage_type, "uuid.UUID")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type takes no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> UUIDArrowType:
        # No parameters were serialized; every instance is equivalent.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[UUIDArrowScalar]:
        return UUIDArrowScalar


# Register the UUID extension type with pyarrow, as is done for the other
# extension types in this module.
pa.register_extension_type(UUIDArrowType())
@final
class UUIDArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `uuid.UUID`.

    Use the standard `as_py` method to convert to an actual `uuid.UUID`
    instance.
    """

    def as_py(self) -> uuid.UUID:
        """Convert the stored bytes back into a `uuid.UUID`.

        Returns
        -------
        value : `uuid.UUID`
            UUID constructed from the bytes held in the storage scalar.
        """
        return uuid.UUID(bytes=self.value.as_py())
@final
class RegionArrowType(pa.ExtensionType):
    """An Arrow extension type for `lsst.sphgeom.Region`."""

    def __init__(self) -> None:
        # Storage is shared with the converter that builds the arrays.
        storage = _ToArrowRegion.storage_type
        super().__init__(storage, "lsst.sphgeom.Region")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type takes no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> RegionArrowType:
        # No parameters were serialized; every instance is equivalent.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[RegionArrowScalar]:
        return RegionArrowScalar
@final
class RegionArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.sphgeom.Region`.

    Use the standard `as_py` method to convert to an actual region.
    """

    def as_py(self) -> Region:
        """Decode the stored bytes back into a region.

        Returns
        -------
        region : `lsst.sphgeom.Region`
            Result of `Region.decode` on the bytes held in the storage
            scalar.
        """
        encoded = self.value.as_py()
        return Region.decode(encoded)
@final
class TimespanArrowType(pa.ExtensionType):
    """An Arrow extension type for `lsst.daf.butler.Timespan`."""

    def __init__(self) -> None:
        # Storage is shared with the converter that builds the arrays.
        storage = _ToArrowTimespan.storage_type
        super().__init__(storage, "lsst.daf.butler.Timespan")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type takes no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> TimespanArrowType:
        # No parameters were serialized; every instance is equivalent.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[TimespanArrowScalar]:
        return TimespanArrowScalar
@final
class TimespanArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.daf.butler.Timespan`.

    Use the standard `as_py` method to convert to an actual timespan.
    """

    def as_py(self) -> Timespan:
        """Rebuild a timespan from the stored struct.

        Returns
        -------
        timespan : `lsst.daf.butler.Timespan`
            Timespan constructed from the ``begin_nsec`` and ``end_nsec``
            fields of the storage scalar.
        """
        storage = self.value
        # The two bounds live in separate struct fields; the end bound must
        # be read from "end_nsec", not from "begin_nsec" a second time.
        begin_nsec = storage["begin_nsec"].as_py()
        end_nsec = storage["end_nsec"].as_py()
        return Timespan(None, None, _nsec=(begin_nsec, end_nsec))
@final
class DateTimeArrowType(pa.ExtensionType):
    """An Arrow extension type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.
    """

    def __init__(self) -> None:
        # The storage type has to be the one `_ToArrowDateTime.finish` builds
        # its storage array with (a 64-bit integer), which is also what
        # `DateTimeArrowScalar.as_py` expects to read back.
        super().__init__(_ToArrowDateTime.storage_type, "astropy.time.Time")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type takes no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> DateTimeArrowType:
        # No parameters were serialized; every instance is equivalent.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[DateTimeArrowScalar]:
        return DateTimeArrowScalar
@final
class DateTimeArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.

    Use the standard `as_py` method to convert to an actual `astropy.time.Time`
    instance.
    """

    def as_py(self) -> astropy.time.Time:
        """Convert the stored integer back into a time object.

        Returns
        -------
        time : `astropy.time.Time`
            Result of `TimeConverter.nsec_to_astropy` on the integer held in
            the storage scalar.
        """
        nsec = self.value.as_py()
        return TimeConverter().nsec_to_astropy(nsec)
# Register an instance of each extension type with pyarrow at import time.
# Registration is keyed on the extension name given to
# `pa.ExtensionType.__init__`, so every registered type needs a unique name.
# NOTE(review): `UUIDArrowType` is not in this list -- confirm whether it is
# registered elsewhere.
pa.register_extension_type(RegionArrowType())
pa.register_extension_type(TimespanArrowType())
pa.register_extension_type(DateTimeArrowType())