Coverage for python/lsst/daf/relation/iteration/_row_iterable.py: 44%
83 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-14 01:59 -0800
1# This file is part of daf_relation.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
# Public API of this module.  Kept alphabetical; every class defined below
# appears exactly once.
__all__ = (
    "CalculationRowIterable",
    "ChainRowIterable",
    "MaterializedRowIterable",
    "ProjectionRowIterable",
    "RowIterable",
    "RowMapping",
    "RowSequence",
    "SelectionRowIterable",
    "SliceRowIterable",
)
import itertools
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator, Mapping, Sequence, Set
from typing import Any

from .._columns import ColumnTag
class RowIterable(ABC):
    """An abstract base class for iterables that use mappings for rows.

    `RowIterable` is the `~.Relation.payload` type for the `.iteration` engine.

    Notes
    -----
    Subclasses must implement `__iter__`.  Deriving from `abc.ABC` makes that
    requirement enforced: a subclass that does not implement every abstract
    method cannot be instantiated.  The conversion methods below are
    implemented in terms of iteration, and subclasses may override them to
    avoid copying.
    """

    @abstractmethod
    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        raise NotImplementedError()

    def to_mapping(self, unique_key: Sequence[ColumnTag]) -> RowMapping:
        """Convert this iterable to a `RowMapping`, unless it already is one.

        Parameters
        ----------
        unique_key : `~collections.abc.Sequence` [ `ColumnTag` ]
            Sequence of columns to extract into a `tuple` to use as keys in the
            mapping, guaranteeing uniqueness over these columns.

        Returns
        -------
        rows : `RowMapping`
            A `RowIterable` backed by a mapping.

        Notes
        -----
        If several rows share the same values for ``unique_key``, only the
        last one encountered is kept.
        """
        return RowMapping(unique_key, {tuple(row[k] for k in unique_key): row for row in self})

    def to_sequence(self) -> RowSequence:
        """Convert this iterable to a `RowSequence`, unless it already is one.

        Returns
        -------
        rows : `RowSequence`
            A `RowIterable` backed by a sequence.
        """
        return RowSequence(list(self))

    def materialized(self) -> MaterializedRowIterable:
        """Convert this iterable to one that holds its rows in a Python
        collection of some kind, instead of generating them lazily.

        Returns
        -------
        rows : `MaterializedRowIterable`
            A `RowIterable` that isn't lazy.
        """
        return self.to_sequence()

    def sliced(self, start: int, stop: int | None) -> RowIterable:
        """Apply a slice operation to this `RowIterable`.

        Parameters
        ----------
        start : `int`
            Start index.
        stop : `int` or `None`
            Stop index (one-past-the-end), or `None` to include up through the
            last row.

        Returns
        -------
        rows : `RowIterable`
            Iterable representing the slice. May or may not be lazy.
        """
        return SliceRowIterable(self, start, stop)
class MaterializedRowIterable(RowIterable):
    """A `RowIterable` whose rows are already held in memory rather than
    produced lazily, so its length is known up front.
    """

    def materialized(self) -> MaterializedRowIterable:
        # Docstring inherited.  Already materialized: nothing to convert.
        return self

    @abstractmethod
    def __len__(self) -> int:
        raise NotImplementedError
class RowMapping(MaterializedRowIterable):
    """A `RowIterable` backed by a `~collections.abc.Mapping`.

    Parameters
    ----------
    unique_key : `~collections.abc.Sequence` [ `ColumnTag` ]
        Sequence of columns to extract into a `tuple` to use as keys in the
        mapping, guaranteeing uniqueness over these columns.
    rows : `collections.abc.Mapping`
        Nested mapping with `tuple` keys and row values, where each row is
        (as usual for `RowIterable` types) itself a `Mapping` with `.ColumnTag`
        keys.
    """

    def __init__(self, unique_key: Sequence[ColumnTag], rows: Mapping[tuple, Mapping[ColumnTag, Any]]):
        self.rows = rows
        self.unique_key = unique_key

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        # Rows are the mapping's values; the tuple keys are an index only.
        return iter(self.rows.values())

    def __len__(self) -> int:
        return len(self.rows)

    def to_mapping(self, unique_key: Sequence[ColumnTag]) -> RowMapping:
        # Docstring inherited.
        # Compare as tuples: in Python a list and a tuple holding the same
        # columns compare unequal, which would otherwise force a needless
        # rebuild of an equivalent mapping.
        if tuple(unique_key) == tuple(self.unique_key):
            return self
        return super().to_mapping(unique_key)
class RowSequence(MaterializedRowIterable):
    """A `RowIterable` backed by a `~collections.abc.Sequence`.

    Parameters
    ----------
    rows : `~collections.abc.Sequence` [ `~collections.abc.Mapping` ]
        Sequence of rows, where each row is (as usual for `RowIterable` types)
        a `Mapping` with `.ColumnTag` keys.
    """

    def __init__(self, rows: Sequence[Mapping[ColumnTag, Any]]):
        self.rows = rows

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        return iter(self.rows)

    def __len__(self) -> int:
        return len(self.rows)

    def to_sequence(self) -> RowSequence:
        # Docstring inherited.
        # Already backed by a sequence; return it without copying.
        return self

    def sliced(self, start: int, stop: int | None) -> RowIterable:
        # Docstring inherited.
        # Slice the underlying sequence directly so the result stays
        # materialized (and keeps a known length).
        return RowSequence(self.rows[start:stop])
class CalculationRowIterable(RowIterable):
    """A `RowIterable` that adds (or replaces) one computed column in every
    row of another iterable.

    Parameters
    ----------
    target : `RowIterable`
        Iterable whose rows are extended.
    tag : `ColumnTag`
        Key under which the computed value is stored in each result row.
    callable : `Callable`
        Function invoked with each original row mapping; its return value
        becomes the value of the new column.
    """

    def __init__(
        self, target: RowIterable, tag: ColumnTag, callable: Callable[[Mapping[ColumnTag, Any]], Any]
    ):
        self.callable = callable
        self.tag = tag
        self.target = target

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        def extend(row: Mapping[ColumnTag, Any]) -> dict[ColumnTag, Any]:
            # Copy first so the input row is never modified, then attach the
            # computed value under the new column's key.
            extended = dict(row)
            extended[self.tag] = self.callable(row)
            return extended

        return map(extend, self.target)
class ChainRowIterable(RowIterable):
    """A `RowIterable` that yields the rows of several iterables back to back.

    Parameters
    ----------
    chain : `Sequence` [ `RowIterable` ]
        Iterables to concatenate, in the order their rows should appear.
    """

    def __init__(self, chain: Sequence[RowIterable]):
        self.chain = chain

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        # Exhaust each member iterable before moving on to the next one.
        for member in self.chain:
            yield from member
class ProjectionRowIterable(RowIterable):
    """A `RowIterable` that keeps only a chosen subset of columns from each
    row of another iterable.

    Parameters
    ----------
    target : `RowIterable`
        Iterable whose rows are narrowed.
    columns : `Set`
        Columns retained in every result row.  Each is looked up with
        ``row[column]``, so a column missing from a row raises `KeyError`.
    """

    def __init__(self, target: RowIterable, columns: Set[ColumnTag]):
        self.target = target
        self.columns = columns

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        def project(row: Mapping[ColumnTag, Any]) -> dict[ColumnTag, Any]:
            # Build a fresh mapping holding only the requested columns.
            return {column: row[column] for column in self.columns}

        return map(project, self.target)
class SelectionRowIterable(RowIterable):
    """A `RowIterable` that passes through only the rows of another iterable
    that satisfy a predicate.

    Parameters
    ----------
    target : `RowIterable`
        Iterable whose rows are filtered.
    callable : `Callable`
        Predicate called with each row mapping; rows for which it returns a
        true value are kept.
    """

    def __init__(self, target: RowIterable, callable: Callable[[Mapping[ColumnTag, Any]], bool]):
        self.callable = callable
        self.target = target

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        for candidate in self.target:
            if not self.callable(candidate):
                continue
            yield candidate
class SliceRowIterable(RowIterable):
    """A `RowIterable` that implements a lazy `Slice` operation.

    Parameters
    ----------
    target : `RowIterable`
        Original iterable.
    start : `int`
        Start index.
    stop : `int` or `None`
        Stop index (one-past-the-end), or `None` to include up through the
        last row.

    Notes
    -----
    Rows are pulled from ``target`` only as needed.  When ``stop`` is not
    `None`, at most ``max(start, stop)`` rows are consumed, so upstream rows
    past the end of the slice are never generated.
    """

    def __init__(self, target: RowIterable, start: int, stop: int | None):
        self.target = target
        self.start = start
        self.stop = stop

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        # Negative bounds are outside the documented contract.  For backward
        # compatibility, treat them as "no bound" instead of letting
        # `itertools.islice` raise `ValueError`.
        start = max(self.start, 0)
        stop = None if self.stop is None or self.stop < 0 else self.stop
        # `islice` checks the stop bound *before* fetching the next row.  That
        # avoids pulling (and computing) one extra row from ``target`` just to
        # find out the slice has ended.
        yield from itertools.islice(self.target, start, stop)