Coverage for python/lsst/daf/relation/_operations/_join.py: 32%
141 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-13 10:03 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-13 10:03 +0000
1# This file is part of daf_relation.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("Join", "PartialJoin")
26import dataclasses
27from collections.abc import Set
28from typing import TYPE_CHECKING, final
30from .._binary_operation import BinaryOperation, IgnoreOne
31from .._columns import ColumnTag, Predicate
32from .._exceptions import ColumnError, EngineError
33from .._operation_relations import UnaryOperationRelation
34from .._unary_operation import UnaryCommutator, UnaryOperation
36if TYPE_CHECKING:
37 from .._engine import Engine
38 from .._relation import Relation
@final
@dataclasses.dataclass(frozen=True)
class Join(BinaryOperation):
    """Binary operation that performs a natural join of two relations.

    Rows from the two operands are matched when they agree on the values of
    their common columns and, optionally, also satisfy an extra boolean
    column expression (a `Predicate`).  The resulting relation has the union
    of the operands' columns.  In SQL terms this is an [``INNER``] ``JOIN``.
    """

    predicate: Predicate = dataclasses.field(default_factory=lambda: Predicate.literal(True))
    """Boolean expression every matched row must satisfy (`Predicate`).

    The equality constraint on `common_columns` is applied separately and is
    not part of this expression.
    """

    min_columns: frozenset[ColumnTag] = frozenset()
    """Lower bound on the columns used in the equality constraint on
    `common_columns` (`frozenset` [ `ColumnTag` ]).

    `ColumnError` is raised by `apply` when the columns the operand relations
    have in common are not a superset of this set.

    On any `Join` instance attached to a `BinaryOperationRelation` by `apply`
    this is guaranteed to be equal to `max_columns`.
    """

    max_columns: frozenset[ColumnTag] | None = None
    """Upper bound on the columns used in the equality constraint on
    `common_columns` (`frozenset` [ `ColumnTag` ] or ``None``).

    Columns the operand relations have in common beyond this set are left out
    of the equality constraint.

    On any `Join` instance attached to a `BinaryOperationRelation` by `apply`
    this is guaranteed to be equal to `min_columns`.
    """

    def __post_init__(self) -> None:
        # An unset upper bound is compatible with any lower bound; otherwise
        # the lower bound must be contained in the upper bound.
        if self.max_columns is None or self.min_columns <= self.max_columns:
            return
        raise ColumnError(
            f"Join min_columns={self.min_columns} is not a subset of max_columns={self.max_columns}."
        )

    @property
    def common_columns(self) -> frozenset[ColumnTag]:
        """The resolved set of columns shared by both operands that is used
        as the equality constraint (`frozenset` [ `ColumnTag` ]).

        Accessing this raises `ColumnError` while `min_columns` and
        `max_columns` differ.  It is always available on a `Join` instance
        attached to a `BinaryOperationRelation` by `apply`.
        """
        if self.max_columns != self.min_columns:
            raise ColumnError(f"Common columns for join {self} have not been resolved.")
        return self.min_columns

    def __str__(self) -> str:
        return "⋈"

    def _begin_apply(self, lhs: Relation, rhs: Relation) -> BinaryOperation:
        # Docstring inherited.
        # The predicate may only reference columns the joined relation has.
        available = self.applied_columns(lhs, rhs)
        if not self.predicate.columns_required <= available:
            raise ColumnError(
                f"Missing columns {set(self.predicate.columns_required - available)} "
                f"for join between {lhs!r} and {rhs!r} with predicate {self.predicate}."
            )
        if self.max_columns == self.min_columns:
            # Common columns are already pinned down; both operands must
            # actually provide them.
            resolved = self.common_columns
            if not lhs.columns >= resolved:
                raise ColumnError(
                    f"Missing columns {set(resolved - lhs.columns)} "
                    f"for left-hand side of join between {lhs!r} and {rhs!r}."
                )
            if not rhs.columns >= resolved:
                raise ColumnError(
                    f"Missing columns {set(resolved - rhs.columns)} "
                    f"for right-hand side of join between {lhs!r} and {rhs!r}."
                )
            operation = self
        else:
            # Resolve the common columns against these operands and freeze
            # them into a new Join with min_columns == max_columns.
            resolved = self.applied_common_columns(lhs, rhs)
            operation = dataclasses.replace(self, min_columns=resolved, max_columns=resolved)
        # The join-identity shortcuts are taken only after the validation
        # above has passed.
        if lhs.is_join_identity:
            return IgnoreOne(True)
        if rhs.is_join_identity:
            return IgnoreOne(False)
        return operation

    def _finish_apply(self, lhs: Relation, rhs: Relation) -> Relation:
        # Docstring inherited.
        # Joining with a join identity just yields the other operand.
        if lhs.is_join_identity:
            return rhs
        if rhs.is_join_identity:
            return lhs
        engine = lhs.engine
        if engine != rhs.engine:
            raise EngineError(f"Mismatched join engines: {engine} != {rhs.engine}.")
        if not self.predicate.is_supported_by(engine):
            raise EngineError(f"Join predicate {self.predicate} does not support engine {engine}.")
        return super()._finish_apply(lhs, rhs)

    def applied_columns(self, lhs: Relation, rhs: Relation) -> Set[ColumnTag]:
        # Docstring inherited.
        return lhs.columns | rhs.columns

    def applied_min_rows(self, lhs: Relation, rhs: Relation) -> int:
        # Docstring inherited.
        return 0

    def applied_max_rows(self, lhs: Relation, rhs: Relation) -> int | None:
        # Docstring inherited.
        left_max, right_max = lhs.max_rows, rhs.max_rows
        # A known-empty operand wins even when the other bound is unknown.
        if left_max == 0 or right_max == 0:
            return 0
        if left_max is None or right_max is None:
            return None
        return left_max * right_max

    def applied_common_columns(self, lhs: Relation, rhs: Relation) -> frozenset[ColumnTag]:
        """Work out the common columns this `Join` would use for the given
        operands.

        Parameters
        ----------
        lhs : `Relation`
            One relation to join.
        rhs : `Relation`
            The other relation to join to ``lhs``.

        Returns
        -------
        common_columns : `frozenset` [ `ColumnTag` ]
            `min_columns` itself when it already equals `max_columns`;
            otherwise the columns with a true ``is_key`` that appear in both
            ``lhs.columns`` and ``rhs.columns``, further limited to
            `max_columns` when that is not ``None``.

        Raises
        ------
        ColumnError
            Raised if the computed set is not a superset of `min_columns`.
        """
        if self.max_columns == self.min_columns:
            return self.min_columns
        shared = {tag for tag in lhs.columns & rhs.columns if tag.is_key}
        if self.max_columns is not None:
            shared &= self.max_columns
        if not (shared >= self.min_columns):
            raise ColumnError(
                f"Common columns {shared} for join between {lhs} and {rhs} are not a superset "
                f"of the minimum columns {self.min_columns}."
            )
        return frozenset(shared)

    def partial(self, fix: Relation, is_lhs: bool = False) -> PartialJoin:
        """Bind one operand of this join, yielding a `UnaryOperation` that
        only needs the other operand.

        Parameters
        ----------
        fix : `Relation`
            Relation to hold fixed inside the returned unary operation.
        is_lhs : `bool`, optional
            Whether ``fix`` plays the ``lhs`` or the ``rhs`` role in the join
            (`Join` side is *usually* irrelevant, but `Engine`
            implementations are permitted to make additional guarantees about
            row order or duplicates based on them).

        Returns
        -------
        partial_join : `PartialJoin`
            Unary operation representing a join to a fixed relation.

        Raises
        ------
        ColumnError
            Raised if ``fix`` does not have all of the columns in
            `min_columns`.

        Notes
        -----
        The name "partial" (for both this method and the class it returns)
        follows `functools.partial`: a callable obtained from another
        callable by holding some of its arguments fixed.
        """
        if not (self.min_columns <= fix.columns):
            raise ColumnError(
                f"Missing columns {set(self.min_columns - fix.columns)} for partial join to {fix}."
            )
        return PartialJoin(self, fix, is_lhs)
@final
@dataclasses.dataclass(frozen=True)
class PartialJoin(UnaryOperation):
    """Unary view of a `Join` in which one of the two operands has already
    been supplied and is held fixed.

    Notes
    -----
    The name "partial" (for both this class and the `Join.partial` method
    that builds it) follows `functools.partial`: a callable obtained from
    another callable by holding some of its arguments fixed.

    `PartialJoin` instances never appear in relation trees; the `apply` method
    will return a `BinaryOperationRelation` with a `Join` operation instead of
    a `UnaryOperationRelation` with a `PartialJoin` (or one of the operands, if
    the other is a `join identity relation <Relation.is_join_identity>`).
    """

    binary: Join
    """The underlying join operation to apply (`Join`).
    """

    fixed: Relation
    """The operand that is already bound into this operation (`Relation`).
    """

    fixed_is_lhs: bool
    """Whether `fixed` takes the ``lhs`` (`True`) or ``rhs`` (`False`) role
    in the join.

    `Join` side is *usually* irrelevant, but `Engine` implementations are
    permitted to make additional guarantees about row order or duplicates based
    on them.
    """

    @property
    def columns_required(self) -> Set[ColumnTag]:
        # Docstring inherited.
        # Predicate columns that `fixed` already supplies need not come from
        # the target; the join's minimum common columns always must.
        fixed_columns = self.fixed.columns
        needed = {tag for tag in self.binary.predicate.columns_required if tag not in fixed_columns}
        needed.update(self.binary.min_columns)
        return needed

    @property
    def is_empty_invariant(self) -> bool:
        # Docstring inherited.
        return False

    @property
    def is_count_invariant(self) -> bool:
        # Docstring inherited.
        return False

    def __str__(self) -> str:
        return f"{self.binary!s}[{self.fixed!s}]"

    def _begin_apply(
        self, target: Relation, preferred_engine: Engine | None
    ) -> tuple[UnaryOperation, Engine]:
        # Docstring inherited.
        join = self.binary
        if join.max_columns != join.min_columns:
            # Common columns are still open: resolve them against this target
            # and start over with a PartialJoin whose Join has them pinned.
            resolved = join.applied_common_columns(self.fixed, target)
            pinned_join = dataclasses.replace(join, min_columns=resolved, max_columns=resolved)
            pinned = dataclasses.replace(self, binary=pinned_join)
            return pinned._begin_apply(target, preferred_engine)
        if preferred_engine is None:
            # Fall back to the engine of the relation that is already bound.
            preferred_engine = self.fixed.engine
        required = self.columns_required
        if not required <= target.columns:
            raise ColumnError(
                f"Join {self} to relation {target} needs columns "
                f"{set(required) - target.columns}."
            )
        return super()._begin_apply(target, preferred_engine)

    def _finish_apply(self, target: Relation) -> Relation:
        # Docstring inherited.
        # Put the fixed relation on the side it was bound to.
        lhs, rhs = (self.fixed, target) if self.fixed_is_lhs else (target, self.fixed)
        return self.binary.apply(lhs, rhs)

    def applied_columns(self, target: Relation) -> Set[ColumnTag]:
        # Docstring inherited.
        lhs, rhs = (self.fixed, target) if self.fixed_is_lhs else (target, self.fixed)
        return self.binary.applied_columns(lhs, rhs)

    def applied_min_rows(self, target: Relation) -> int:
        # Docstring inherited.
        lhs, rhs = (self.fixed, target) if self.fixed_is_lhs else (target, self.fixed)
        return self.binary.applied_min_rows(lhs, rhs)

    def applied_max_rows(self, target: Relation) -> int | None:
        # Docstring inherited.
        lhs, rhs = (self.fixed, target) if self.fixed_is_lhs else (target, self.fixed)
        return self.binary.applied_max_rows(lhs, rhs)

    def commute(self, current: UnaryOperationRelation) -> UnaryCommutator:
        # Docstring inherited.
        # Imported here rather than at module level, as in the original.
        from ._deduplication import Deduplication
        from ._projection import Projection

        operation = current.operation

        def blocked(reason: str) -> UnaryCommutator:
            # Result used whenever this join cannot be moved past `operation`.
            return UnaryCommutator(first=None, second=operation, done=False, messages=(reason,))

        if isinstance(operation, Deduplication):
            # A Join only commutes past Deduplication if the fixed relation
            # has unique rows, which is not something we can check right
            # now.
            return blocked("join-deduplication commutation is not supported")
        if isinstance(operation, Projection):
            # In order for projection(join(target)) to be equivalent to
            # join(projection(target)), the new outer projection has to
            # include the columns added by the join.  Note that because we
            # require common_columns to be explicit at this point, the
            # projection cannot change them.
            widened = Projection(frozenset(self.applied_columns(current)))
            return UnaryCommutator(first=self, second=widened)
        # Any other operation: the join must still find its required columns
        # below the operation, and the operation must not depend on row
        # counts.
        if not self.columns_required <= current.target.columns:
            return blocked(
                f"{current.target} is missing columns "
                f"{set(self.columns_required - current.target.columns)}"
            )
        if operation.is_count_dependent:
            return blocked(f"{operation} is count-dependent")
        return UnaryCommutator(first=self, second=operation)