Coverage for python/lsst/daf/butler/arrow_utils.py: 80%

139 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-25 10:50 +0000

1# This file is part of butler4. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("ToArrow", "RegionArrowType", "RegionArrowScalar", "TimespanArrowType", "TimespanArrowScalar") 

31 

32from abc import ABC, abstractmethod 

33from typing import Any, ClassVar, final 

34 

35import pyarrow as pa 

36from lsst.sphgeom import Region 

37 

38from ._timespan import Timespan 

39 

40 

class ToArrow(ABC):
    """Abstract interface for turning Python objects into one Arrow field.

    Instances are normally obtained from the ``for_*`` factory methods.
    Values are accumulated into a plain `list` via `append`, and that list
    is turned into an Arrow array by `finish`.
    """

    @staticmethod
    def for_primitive(name: str, data_type: pa.DataType, nullable: bool) -> ToArrow:
        """Build a converter for a type Arrow already supports natively.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        data_type : `pyarrow.DataType`
            Arrow data type object.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        converter = _ToArrowPrimitive(name, data_type, nullable)
        return converter

    @staticmethod
    def for_region(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.sphgeom.Region` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        converter = _ToArrowRegion(name, nullable)
        return converter

    @staticmethod
    def for_timespan(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.daf.butler.Timespan` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        converter = _ToArrowTimespan(name, nullable)
        return converter

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the field."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def nullable(self) -> bool:
        """Whether the field permits null or `None` values."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def data_type(self) -> pa.DataType:
        """Arrow data type for the field this converter produces."""
        raise NotImplementedError()

    @property
    def field(self) -> pa.Field:
        """Arrow field this converter produces.

        Assembled from the `name`, `data_type` and `nullable` properties.
        """
        field_name = self.name
        arrow_type = self.data_type
        return pa.field(field_name, arrow_type, self.nullable)

    def dictionary_encoded(self) -> ToArrow:
        """Wrap this converter in one that dictionary-encodes its output.

        Returns
        -------
        to_arrow : `ToArrow`
            New converter with the same name and value type, but using
            dictionary encoding (to 32-bit integers) to compress duplicate
            values.
        """
        wrapped = _ToArrowDictionary(self)
        return wrapped

    @abstractmethod
    def append(self, value: Any, column: list[Any]) -> None:
        """Append the Arrow representation of an object to a `list`.

        Parameters
        ----------
        value : `object`
            Original value to be converted to a row in an Arrow column.
        column : `list`
            List of values to append to.  What exactly gets appended is up
            to the implementation; the only requirement is that `finish`
            be able to handle this `list` later.
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, column: list[Any]) -> pa.Array:
        """Turn a list of values built up via `append` into an Arrow array.

        Parameters
        ----------
        column : `list`
            List of column values populated by `append`.

        Returns
        -------
        array : `pyarrow.Array`
            Arrow array holding the converted column.
        """
        raise NotImplementedError()

157 

158 

class _ToArrowPrimitive(ToArrow):
    """`ToArrow` implementation for primitive types Arrow supports directly.

    Instances should be created through the `ToArrow.for_primitive` factory
    method rather than by calling this constructor.
    """

    def __init__(self, name: str, data_type: pa.DataType, nullable: bool):
        # Stored verbatim and exposed through the read-only properties below.
        self._name, self._data_type, self._nullable = name, data_type, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return self._data_type

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        # Values are stored untouched; conversion is left to `pa.array` in
        # `finish`.
        column.append(value)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        arrow_type = self._data_type
        return pa.array(column, arrow_type)

192 

193 

class _ToArrowDictionary(ToArrow):
    """`ToArrow` implementation for Arrow dictionary fields.

    Wraps another converter and dictionary-encodes the array it produces.
    Instances should be created through the `ToArrow.dictionary_encoded`
    factory method.
    """

    def __init__(self, to_arrow_value: ToArrow):
        # Converter for the underlying (un-encoded) values; name and
        # nullability are delegated to it.
        self._to_arrow_value = to_arrow_value

    @property
    def name(self) -> str:
        # Docstring inherited.
        inner = self._to_arrow_value
        return inner.name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        inner = self._to_arrow_value
        return inner.nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        # The index type is hard-coded to int32 because that is what the
        # pa.Array.dictionary_encode() method (used in `finish`) does.
        index_type = pa.int32()
        value_type = self._to_arrow_value.data_type
        return pa.dictionary(index_type, value_type)

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        # Accumulation is entirely the wrapped converter's business.
        self._to_arrow_value.append(value, column)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        plain_array = self._to_arrow_value.finish(column)
        return plain_array.dictionary_encode()

227 

228 

class _ToArrowRegion(ToArrow):
    """`ToArrow` implementation for `lsst.sphgeom.Region` fields.

    Should be constructed via the `ToArrow.for_region` factory method.

    Regions are accumulated as the byte strings returned by their ``encode``
    method and stored in an Arrow binary column wrapped in
    `RegionArrowType`.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    # Arrow storage type underlying `RegionArrowType`.
    storage_type: ClassVar[pa.DataType] = pa.binary()

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return RegionArrowType()

    def append(self, value: Region | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        # The field may be nullable, so `None` has to be accepted here; it
        # is passed through unchanged so that `pa.array` in `finish` turns
        # it into an Arrow null instead of this method failing on
        # ``None.encode()``.
        column.append(value.encode() if value is not None else None)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        # Build the plain binary array first, then re-label it with the
        # extension type.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(RegionArrowType(), storage_array)

264 

265 

class _ToArrowTimespan(ToArrow):
    """`ToArrow` implementation for `lsst.daf.butler.Timespan` fields.

    Should be constructed via the `ToArrow.for_timespan` factory method.

    Each timespan is accumulated as a mapping with ``begin_nsec`` and
    ``end_nsec`` integer entries and stored in an Arrow struct column
    wrapped in `TimespanArrowType`.
    """

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    # Arrow storage type underlying `TimespanArrowType`: a struct of the two
    # 64-bit integer bounds.
    storage_type: ClassVar[pa.DataType] = pa.struct(
        [
            pa.field("begin_nsec", pa.int64(), nullable=False),
            pa.field("end_nsec", pa.int64(), nullable=False),
        ]
    )

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return TimespanArrowType()

    def append(self, value: Timespan | None, column: list[pa.StructScalar | None]) -> None:
        # Docstring inherited.
        # The field may be nullable, so `None` has to be accepted here; it
        # is passed through unchanged so that `pa.array` in `finish` turns
        # it into a null struct instead of this method failing on
        # ``None._nsec``.
        if value is None:
            column.append(None)
            return
        begin_nsec, end_nsec = value._nsec[0], value._nsec[1]
        column.append({"begin_nsec": begin_nsec, "end_nsec": end_nsec})

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        # Build the plain struct array first, then re-label it with the
        # extension type.
        storage_array = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(TimespanArrowType(), storage_array)

306 

307 

@final
class RegionArrowType(pa.ExtensionType):
    """An Arrow extension type for lsst.sphgeom.Region.

    The type takes no parameters, so its serialized form is empty and
    deserialization simply constructs a fresh instance.
    """

    def __init__(self) -> None:
        storage = _ToArrowRegion.storage_type
        extension_name = "lsst.sphgeom.Region"
        super().__init__(storage, extension_name)

    def __arrow_ext_serialize__(self) -> bytes:
        # Nothing to record: the type has no parameters.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> RegionArrowType:
        # Both arguments are ignored; every instance is identical.
        instance = cls()
        return instance

    def __arrow_ext_scalar_class__(self) -> type[RegionArrowScalar]:
        # Scalars of this type convert to `Region` via `as_py`.
        return RegionArrowScalar

324 

325 

@final
class RegionArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.sphgeom.Region`.

    Use the standard `as_py` method to convert to an actual region.
    """

    def as_py(self) -> Region | None:
        """Convert this scalar to a region.

        Returns
        -------
        region : `lsst.sphgeom.Region` or `None`
            Region decoded from the stored bytes, or `None` if this scalar
            is null.
        """
        storage = self.value
        # `value` is `None` for a null scalar; without this guard a null
        # entry in a nullable column would raise `AttributeError`.
        if storage is None:
            return None
        return Region.decode(storage.as_py())

335 

336 

@final
class TimespanArrowType(pa.ExtensionType):
    """An Arrow extension type for lsst.daf.butler.Timespan.

    The type takes no parameters, so its serialized form is empty and
    deserialization simply constructs a fresh instance.
    """

    def __init__(self) -> None:
        storage = _ToArrowTimespan.storage_type
        extension_name = "lsst.daf.butler.Timespan"
        super().__init__(storage, extension_name)

    def __arrow_ext_serialize__(self) -> bytes:
        # Nothing to record: the type has no parameters.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> TimespanArrowType:
        # Both arguments are ignored; every instance is identical.
        instance = cls()
        return instance

    def __arrow_ext_scalar_class__(self) -> type[TimespanArrowScalar]:
        # Scalars of this type convert to `Timespan` via `as_py`.
        return TimespanArrowScalar

353 

354 

@final
class TimespanArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.daf.butler.Timespan`.

    Use the standard `as_py` method to convert to an actual timespan.
    """

    def as_py(self) -> Timespan | None:
        """Convert this scalar to a timespan.

        Returns
        -------
        timespan : `lsst.daf.butler.Timespan` or `None`
            Timespan rebuilt from the stored ``begin_nsec`` and ``end_nsec``
            struct entries, or `None` if this scalar is null.
        """
        storage = self.value
        # `value` is `None` for a null scalar; without this guard a null
        # entry in a nullable column would raise `TypeError`.
        if storage is None:
            return None
        # The upper bound must come from "end_nsec"; reading "begin_nsec"
        # twice would collapse every timespan to its begin value.
        begin_nsec = storage["begin_nsec"].as_py()
        end_nsec = storage["end_nsec"].as_py()
        return Timespan(None, None, _nsec=(begin_nsec, end_nsec))

366 

367 

# Register both extension types with Arrow when this module is imported,
# region type first, then timespan type.
for _extension_cls in (RegionArrowType, TimespanArrowType):
    pa.register_extension_type(_extension_cls())
del _extension_cls