Coverage for python/lsst/daf/relation/_operations/_join.py: 28%

144 statements  

coverage.py v6.5.0, created at 2023-02-14 01:59 -0800

1# This file is part of daf_relation. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("Join", "PartialJoin") 

25 

26import dataclasses 

27from collections.abc import Set 

28from typing import TYPE_CHECKING, final 

29 

30from .._binary_operation import BinaryOperation, IgnoreOne 

31from .._columns import ColumnTag, Predicate 

32from .._exceptions import ColumnError, EngineError 

33from .._operation_relations import UnaryOperationRelation 

34from .._unary_operation import UnaryCommutator, UnaryOperation 

35 

36if TYPE_CHECKING:  (36 ↛ 37: line 36 didn't jump to line 37, because the condition on line 36 was never true)

37 from .._engine import Engine 

38 from .._relation import Relation 

39 

40 

41@final 

42@dataclasses.dataclass(frozen=True) 

43class Join(BinaryOperation): 

44 """A natural join operation. 

45 

46 A natural join combines two relations by matching rows with the same values 

47 in their common columns (and satisfying an optional column expression, via 

48 a `Predicate`), producing a new relation whose columns are the union of the 

49 columns of its operands. This is equivalent to [``INNER``] ``JOIN`` in 

50 SQL. 

51 """ 

52 

53 predicate: Predicate = dataclasses.field(default_factory=lambda: Predicate.literal(True))  (53 ↛ exit: line 53 didn't run the lambda on line 53)

54 """A boolean expression that must evaluate to true for any matched rows 

55 (`Predicate`). 

56 

57 This does not include the equality constraint on `common_columns`. 

58 """ 

59 

60 min_columns: frozenset[ColumnTag] = dataclasses.field(default=frozenset()) 

61 """The minimal set of columns that should be used in the equality 

62 constraint on `common_columns` (`frozenset` [ `ColumnTag` ]). 

63 

64 If the relations this operation is applied to have common columns that are 

65 not a superset of this set, `ColumnError` will be raised by `apply`. 

66 

67 This is guaranteed to be equal to `max_columns` on any `Join` instance 

68 attached to a `BinaryOperationRelation` by `apply`. 

69 """ 

70 

71 max_columns: frozenset[ColumnTag] | None = dataclasses.field(default=None) 

72 """The maximal set of columns that should be used in the equality 

73 constraint on `common_columns` (`frozenset` [ `ColumnTag` ] or 

74 ``None``). 

75 

76 If the relations this operation is applied to have more columns in common 

77 than this set, they will not be included in the equality constraint. 

78 

79 This is guaranteed to be equal to `min_columns` on any `Join` instance 

80 attached to a `BinaryOperationRelation` by `apply`. 

81 """ 

82 
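
A hedged construction sketch (the `ColumnTag` instances and the `flux_cut` predicate are hypothetical; only the keyword names come from the fields above):

    # Sketch: constrain which shared key columns participate in the equality
    # constraint.  detector_tag and visit_tag are hypothetical ColumnTags.
    join = Join(
        predicate=flux_cut,                                 # hypothetical Predicate; default is literal True
        min_columns=frozenset({detector_tag}),              # must be part of the constraint
        max_columns=frozenset({detector_tag, visit_tag}),   # nothing outside this set is matched on
    )
    # __post_init__ enforces min_columns <= max_columns:
    # Join(min_columns=frozenset({visit_tag}), max_columns=frozenset({detector_tag}))  -> ColumnError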

83 def __post_init__(self) -> None: 

84 if self.max_columns is not None and not self.min_columns <= self.max_columns: 

85 raise ColumnError( 

86 f"Join min_columns={self.min_columns} is not a subset of max_columns={self.max_columns}." 

87 ) 

88 

89 @property 

90 def common_columns(self) -> frozenset[ColumnTag]: 

91 """The common columns between relations that will be used as an 

92 equality constraint (`~collections.abc.Set` [ `ColumnTag` ]). 

93 

94 This attribute is not available on `Join` instances for which 

95 `min_columns` is not the same as `max_columns`. It is always available 

96 on any `Join` instance attached to a `BinaryOperationRelation` by 

97 `apply`. 

98 """ 

99 if self.max_columns == self.min_columns: 

100 return self.min_columns 

101 else: 

102 raise ColumnError(f"Common columns for join {self} have not been resolved.") 

103 

104 def __str__(self) -> str: 

105 return "⋈" 

106 

107 def _begin_apply(self, lhs: Relation, rhs: Relation) -> BinaryOperation: 

108 # Docstring inherited. 

109 if not self.predicate.columns_required <= self.applied_columns(lhs, rhs): 

110 raise ColumnError( 

111 f"Missing columns {set(self.predicate.columns_required - self.applied_columns(lhs, rhs))} " 

112 f"for join between {lhs!r} and {rhs!r} with predicate {self.predicate}." 

113 ) 

114 if self.max_columns != self.min_columns: 

115 common_columns = self.applied_common_columns(lhs, rhs) 

116 operation = dataclasses.replace(self, min_columns=common_columns, max_columns=common_columns) 

117 else: 

118 if not lhs.columns >= self.common_columns: 

119 raise ColumnError( 

120 f"Missing columns {set(self.common_columns - lhs.columns)} " 

121 f"for left-hand side of join between {lhs!r} and {rhs!r}." 

122 ) 

123 if not rhs.columns >= self.common_columns: 

124 raise ColumnError( 

125 f"Missing columns {set(self.common_columns - rhs.columns)} " 

126 f"for right-hand side of join between {lhs!r} and {rhs!r}." 

127 ) 

128 operation = self 

129 if lhs.is_join_identity: 

130 return IgnoreOne(True) 

131 if rhs.is_join_identity: 

132 return IgnoreOne(False) 

133 return operation 

134 

135 def _finish_apply(self, lhs: Relation, rhs: Relation) -> Relation: 

136 # Docstring inherited. 

137 if lhs.is_join_identity: 

138 return rhs 

139 if rhs.is_join_identity: 

140 return lhs 

141 if lhs.engine != rhs.engine: 

142 raise EngineError(f"Mismatched join engines: {lhs.engine} != {rhs.engine}.") 

143 if not self.predicate.is_supported_by(lhs.engine): 

144 raise EngineError(f"Join predicate {self.predicate} is not supported by engine {lhs.engine}.") 

145 return super()._finish_apply(lhs, rhs) 

146 
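
A usage sketch of the checks above, assuming hypothetical relations `lhs`, `rhs`, and `identity` in the same engine, with `identity.is_join_identity` true (`apply` itself is inherited from `BinaryOperation`):

    # Sketch only; lhs, rhs, and identity are hypothetical relations.
    join = Join()                    # natural join on all common key columns
    joined = join.apply(lhs, rhs)    # BinaryOperationRelation; columns == lhs.columns | rhs.columns

    # Joining with a join identity is a no-op: the other operand comes back
    # unchanged instead of being wrapped in a new relation.
    same_as_lhs = join.apply(lhs, identity)

    # Mixing engines, or using a predicate the engine cannot evaluate, raises
    # EngineError before any new relation is constructed.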

147 def applied_columns(self, lhs: Relation, rhs: Relation) -> Set[ColumnTag]: 

148 # Docstring inherited. 

149 return lhs.columns | rhs.columns 

150 

151 def applied_min_rows(self, lhs: Relation, rhs: Relation) -> int: 

152 # Docstring inherited. 

153 return 0 

154 

155 def applied_max_rows(self, lhs: Relation, rhs: Relation) -> int | None: 

156 # Docstring inherited. 

157 if lhs.max_rows == 0 or rhs.max_rows == 0: 

158 return 0 

159 if lhs.max_rows is None or rhs.max_rows is None: 

160 return None 

161 else: 

162 return lhs.max_rows * rhs.max_rows 

163 

164 def applied_common_columns(self, lhs: Relation, rhs: Relation) -> frozenset[ColumnTag]: 

165 """Compute the actual common columns for a `Join` given its targets. 

166 

167 Parameters 

168 ---------- 

169 lhs : `Relation` 

170 One relation to join. 

171 rhs : `Relation` 

172 The other relation to join to ``lhs``. 

173 

174 Returns 

175 ------- 

176 common_columns : `~collections.abc.Set` [ `ColumnTag` ] 

177 Columns that are included in all of ``lhs.columns`` and 

178 ``rhs.columns`` and `max_columns`, checked to be a superset of 

179 `min_columns`. 

180 

181 Raises 

182 ------ 

183 ColumnError 

184 Raised if the result would not be a superset of `min_columns`. 

185 """ 

186 

187 if self.max_columns != self.min_columns: 

188 common_columns = {tag for tag in lhs.columns & rhs.columns if tag.is_key} 

189 if self.max_columns is not None: 

190 common_columns &= self.max_columns 

191 if not (common_columns >= self.min_columns): 

192 raise ColumnError( 

193 f"Common columns {common_columns} for join between {lhs} and {rhs} are not a superset " 

194 f"of the minimum columns {self.min_columns}." 

195 ) 

196 return frozenset(common_columns) 

197 else: 

198 return self.min_columns 

199 
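
A sketch of how resolution plays out for the default `Join()` (hypothetical relations; `is_key` is part of the `ColumnTag` interface used above):

    # Sketch: lhs and rhs are hypothetical relations.
    join = Join()                          # min_columns == frozenset(), max_columns is None
    # join.common_columns would raise ColumnError here: not yet resolved.
    resolved = join.applied_common_columns(lhs, rhs)
    # resolved == frozenset(tag for tag in lhs.columns & rhs.columns if tag.is_key)
    # apply() attaches a copy of the Join with min_columns == max_columns == resolved,
    # so common_columns is always available on operations found in a relation tree.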

200 def partial(self, fix: Relation, is_lhs: bool = False) -> PartialJoin: 

201 """Return a `UnaryOperation` that represents this join with one operand 

202 already provided and held fixed. 

203 

204 Parameters 

205 ---------- 

206 fix : `Relation` 

207 Relation to include in the returned unary operation. 

208 is_lhs : `bool`, optional 

209 Whether ``fix`` should be considered the ``lhs`` or ``rhs`` side of 

210 the join (`Join` side is *usually* irrelevant, but `Engine` 

211 implementations are permitted to make additional guarantees about 

212 row order or duplicates based on them). 

213 

214 Returns 

215 ------- 

216 partial_join : `PartialJoin` 

217 Unary operation representing a join to a fixed relation. 

218 

219 Raises 

220 ------ 

221 ColumnError 

222 Raised if the given predicate requires columns not present in 

223 ``lhs`` or ``rhs``. 

224 RowOrderError 

225 Raised if ``lhs`` or ``rhs`` is unnecessarily ordered; see 

226 `Relation.expect_unordered`. 

227 

228 Notes 

229 ----- 

230 This method and the class it returns are called "partial" in the spirit 

231 of `functools.partial`: a callable formed by holding some arguments to 

232 another callable fixed. 

233 """ 

234 if not (self.min_columns <= fix.columns): 

235 raise ColumnError( 

236 f"Missing columns {set(self.min_columns - fix.columns)} for partial join to {fix}." 

237 ) 

238 return PartialJoin(self, fix, is_lhs) 

239 
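
A usage sketch in the spirit of `functools.partial` (the relations `calibrations` and `exposures` are hypothetical):

    # Sketch: hold one operand fixed, then join anything else to it later.
    join_to_calibs = Join().partial(calibrations)   # PartialJoin; calibrations is the rhs by default
    joined = join_to_calibs.apply(exposures)
    # The result is the same relation tree that Join().apply(exposures, calibrations)
    # would build; no PartialJoin node appears in it.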

240 

241@final 

242@dataclasses.dataclass(frozen=True) 

243class PartialJoin(UnaryOperation): 

244 """A `UnaryOperation` that represents this join with one operand already 

245 provided and held fixed. 

246 

247 Notes 

248 ----- 

249 This class and the `Join.partial` used to construct it are called "partial" 

250 in the spirit of `functools.partial`: a callable formed by holding some 

251 arguments to another callable fixed. 

252 

253 `PartialJoin` instances never appear in relation trees; the `apply` method 

254 will return a `BinaryOperationRelation` with a `Join` operation instead of 

255 a `UnaryOperationRelation` with a `PartialJoin` (or one of the operands, if 

256 the other is a `join identity relation <Relation.is_join_identity>`). 

257 """ 

258 

259 binary: Join 

260 """The join operation (`Join`) to be applied. 

261 """ 

262 

263 fixed: Relation 

264 """The target relation already included in the operation (`Relation`). 

265 """ 

266 

267 fixed_is_lhs: bool 

268 """Whether `fixed` should be considered the ``lhs`` or ``rhs`` side of 

269 the join. 

270 

271 `Join` side is *usually* irrelevant, but `Engine` implementations are 

272 permitted to make additional guarantees about row order or duplicates based 

273 on them. 

274 """ 

275 

276 @property 

277 def columns_required(self) -> Set[ColumnTag]: 

278 # Docstring inherited. 

279 result = set(self.binary.predicate.columns_required) 

280 result.difference_update(self.fixed.columns) 

281 result.update(self.binary.min_columns) 

282 return result 

283 
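
A worked example of the set arithmetic above, with plain Python sets standing in for the column-tag sets (all values hypothetical):

    # Columns the fixed relation already supplies are not demanded from the
    # target; the resolved join columns (min_columns) always are.
    predicate_columns = {"a", "b"}   # self.binary.predicate.columns_required
    fixed_columns = {"a", "x"}       # self.fixed.columns
    min_columns = {"x"}              # self.binary.min_columns
    assert (predicate_columns - fixed_columns) | min_columns == {"b", "x"}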

284 @property 

285 def is_empty_invariant(self) -> bool: 

286 # Docstring inherited. 

287 return False 

288 

289 @property 

290 def is_count_invariant(self) -> bool: 

291 # Docstring inherited. 

292 return False 

293 

294 def __str__(self) -> str: 

295 return f"{self.binary!s}[{self.fixed!s}]" 

296 

297 def _begin_apply( 

298 self, target: Relation, preferred_engine: Engine | None 

299 ) -> tuple[UnaryOperation, Engine]: 

300 # Docstring inherited. 

301 if self.binary.max_columns != self.binary.min_columns: 

302 common_columns = self.binary.applied_common_columns(self.fixed, target) 

303 replacement = dataclasses.replace( 

304 self, 

305 binary=dataclasses.replace( 

306 self.binary, min_columns=common_columns, max_columns=common_columns 

307 ), 

308 ) 

309 return replacement._begin_apply(target, preferred_engine) 

310 if preferred_engine is None: 

311 preferred_engine = self.fixed.engine 

312 if not self.columns_required <= target.columns: 

313 raise ColumnError( 

314 f"Join {self} to relation {target} needs columns " 

315 f"{set(self.columns_required) - target.columns}." 

316 ) 

317 return super()._begin_apply(target, preferred_engine) 

318 

319 def _finish_apply(self, target: Relation) -> Relation: 

320 # Docstring inherited. 

321 if self.fixed_is_lhs: 

322 return self.binary.apply(self.fixed, target) 

323 else: 

324 return self.binary.apply(target, self.fixed) 

325 

326 def applied_columns(self, target: Relation) -> Set[ColumnTag]: 

327 # Docstring inherited. 

328 if self.fixed_is_lhs: 

329 return self.binary.applied_columns(self.fixed, target) 

330 else: 

331 return self.binary.applied_columns(target, self.fixed) 

332 

333 def applied_min_rows(self, target: Relation) -> int: 

334 # Docstring inherited. 

335 if self.fixed_is_lhs: 

336 return self.binary.applied_min_rows(self.fixed, target) 

337 else: 

338 return self.binary.applied_min_rows(target, self.fixed) 

339 

340 def applied_max_rows(self, target: Relation) -> int | None: 

341 # Docstring inherited. 

342 if self.fixed_is_lhs: 

343 return self.binary.applied_max_rows(self.fixed, target) 

344 else: 

345 return self.binary.applied_max_rows(target, self.fixed) 

346 

347 def commute(self, current: UnaryOperationRelation) -> UnaryCommutator: 

348 # Docstring inherited. 

349 from ._deduplication import Deduplication 

350 from ._projection import Projection 

351 

352 match current.operation: 

353 case Deduplication(): 

354 # A Join only commutes past Deduplication if the fixed relation 

355 # has unique rows, which is not something we can check right 

356 # now. 

357 return UnaryCommutator( 

358 first=None, 

359 second=current.operation, 

360 done=False, 

361 messages=("join-deduplication commutation is not supported",), 

362 ) 

363 case Projection(): 

364 # In order for projection(join(target)) to be equivalent to 

365 # join(projection(target)), the new outer projection has to 

366 # include the columns added by the join. Note that because we 

367 # require common_columns to be explicit at this point, the 

368 # projection cannot change them. 

369 return UnaryCommutator( 

370 first=self, 

371 second=Projection(frozenset(self.applied_columns(current))), 

372 ) 

373 case _: 

374 if not self.columns_required <= current.target.columns: 

375 return UnaryCommutator( 

376 first=None, 

377 second=current.operation, 

378 done=False, 

379 messages=( 

380 f"{current.target} is missing columns " 

381 f"{set(self.columns_required - current.target.columns)}", 

382 ), 

383 ) 

384 if current.operation.is_count_dependent: 

385 return UnaryCommutator( 

386 first=None, 

387 second=current.operation, 

388 done=False, 

389 messages=(f"{current.operation} is count-dependent",), 

390 ) 

391 return UnaryCommutator(first=self, second=current.operation)
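
A sketch of how a tree-rewriting pass might consume the result of `commute` in the `Projection` case. The relation `current` and the fixed relation `calibrations` are hypothetical; `first`, `second`, and `done` are the `UnaryCommutator` fields used above.

    # Sketch: current is a hypothetical UnaryOperationRelation whose operation
    # is a Projection applied to current.target.
    partial_join = Join().partial(calibrations)
    commutator = partial_join.commute(current)
    if commutator.done:
        # Push the join below the projection: apply the join first, then the
        # widened projection (which now keeps the columns the join added).
        rewritten = commutator.second.apply(commutator.first.apply(current.target))
        # rewritten is equivalent to partial_join.apply(current).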