Coverage for python/lsst/daf/relation/_operations/_join.py: 32%

141 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-19 10:06 +0000

1# This file is part of daf_relation. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("Join", "PartialJoin") 

25 

26import dataclasses 

27from collections.abc import Set 

28from typing import TYPE_CHECKING, final 

29 

30from .._binary_operation import BinaryOperation, IgnoreOne 

31from .._columns import ColumnTag, Predicate 

32from .._exceptions import ColumnError, EngineError 

33from .._operation_relations import UnaryOperationRelation 

34from .._unary_operation import UnaryCommutator, UnaryOperation 

35 

36if TYPE_CHECKING: 

37 from .._engine import Engine 

38 from .._relation import Relation 

39 

40 

@final
@dataclasses.dataclass(frozen=True)
class Join(BinaryOperation):
    """A natural join operation.

    A natural join combines two relations by matching rows with the same values
    in their common columns (and satisfying an optional column expression, via
    a `Predicate`), producing a new relation whose columns are the union of the
    columns of its operands.  This is equivalent to [``INNER``] ``JOIN`` in
    SQL.
    """

    predicate: Predicate = dataclasses.field(default_factory=lambda: Predicate.literal(True))
    """A boolean expression that must evaluate to true for any matched rows
    (`Predicate`).

    This does not include the equality constraint on `common_columns`.
    """

    min_columns: frozenset[ColumnTag] = dataclasses.field(default=frozenset())
    """The minimal set of columns that should be used in the equality
    constraint on `common_columns` (`frozenset` [ `ColumnTag` ]).

    If the relations this operation is applied to have common columns that are
    not a superset of this set, `ColumnError` will be raised by `apply`.

    This is guaranteed to be equal to `max_columns` on any `Join` instance
    attached to a `BinaryOperationRelation` by `apply`.
    """

    max_columns: frozenset[ColumnTag] | None = dataclasses.field(default=None)
    """The maximal set of columns that should be used in the equality
    constraint on `common_columns` (`frozenset` [ `ColumnTag` ] or
    ``None``).

    If the relations this operation is applied to have more columns in common
    than this set, they will not be included in the equality constraint.

    This is guaranteed to be equal to `min_columns` on any `Join` instance
    attached to a `BinaryOperationRelation` by `apply`.
    """

    def __post_init__(self) -> None:
        # Reject contradictory bounds at construction.  ``max_columns=None``
        # places no upper bound on the common columns, so there is nothing to
        # validate in that case.
        if self.max_columns is not None and not self.min_columns <= self.max_columns:
            raise ColumnError(
                f"Join min_columns={self.min_columns} is not a subset of max_columns={self.max_columns}."
            )

    @property
    def common_columns(self) -> frozenset[ColumnTag]:
        """The common columns between relations that will be used as an
        equality constraint (`frozenset` [ `ColumnTag` ]).

        This attribute is not available on `Join` instances for which
        `min_columns` is not the same as `max_columns`; accessing it on such
        an instance raises `ColumnError`.  It is always available on any
        `Join` instance attached to a `BinaryOperationRelation` by `apply`.
        """
        if self.max_columns == self.min_columns:
            return self.min_columns
        else:
            raise ColumnError(f"Common columns for join {self} have not been resolved.")

    def __str__(self) -> str:
        return "⋈"

    def _begin_apply(self, lhs: Relation, rhs: Relation) -> BinaryOperation:
        # Docstring inherited.
        # The predicate may only reference columns present in the join result.
        available = self.applied_columns(lhs, rhs)
        if not self.predicate.columns_required <= available:
            raise ColumnError(
                f"Missing columns {set(self.predicate.columns_required - available)} "
                f"for join between {lhs!r} and {rhs!r} with predicate {self.predicate}."
            )
        if self.max_columns != self.min_columns:
            # Common columns are still unresolved: compute them from the
            # operands and pin both bounds to that result.
            common_columns = self.applied_common_columns(lhs, rhs)
            operation = dataclasses.replace(self, min_columns=common_columns, max_columns=common_columns)
        else:
            # Common columns were fixed up front; both operands must actually
            # provide every one of them.
            resolved = self.common_columns
            if not lhs.columns >= resolved:
                raise ColumnError(
                    f"Missing columns {set(resolved - lhs.columns)} "
                    f"for left-hand side of join between {lhs!r} and {rhs!r}."
                )
            if not rhs.columns >= resolved:
                raise ColumnError(
                    f"Missing columns {set(resolved - rhs.columns)} "
                    f"for right-hand side of join between {lhs!r} and {rhs!r}."
                )
            operation = self
        # These checks deliberately come after the column validation above,
        # so invalid column specifications are reported even when one operand
        # is a join identity.
        if lhs.is_join_identity:
            return IgnoreOne(True)
        if rhs.is_join_identity:
            return IgnoreOne(False)
        return operation

    def _finish_apply(self, lhs: Relation, rhs: Relation) -> Relation:
        # Docstring inherited.
        # Joining with a join identity is a no-op: return the other operand.
        if lhs.is_join_identity:
            return rhs
        if rhs.is_join_identity:
            return lhs
        if lhs.engine != rhs.engine:
            raise EngineError(f"Mismatched join engines: {lhs.engine} != {rhs.engine}.")
        # Engines are known to match here, so checking against either one
        # covers both operands.
        if not self.predicate.is_supported_by(lhs.engine):
            raise EngineError(f"Join predicate {self.predicate} does not support engine {lhs.engine}.")
        return super()._finish_apply(lhs, rhs)

    def applied_columns(self, lhs: Relation, rhs: Relation) -> Set[ColumnTag]:
        # Docstring inherited.
        return lhs.columns | rhs.columns

    def applied_min_rows(self, lhs: Relation, rhs: Relation) -> int:
        # Docstring inherited.
        return 0

    def applied_max_rows(self, lhs: Relation, rhs: Relation) -> int | None:
        # Docstring inherited.
        # A known-empty operand makes the result empty even when the other
        # operand's bound is unknown, so this test must precede the ``None``
        # test below.
        if lhs.max_rows == 0 or rhs.max_rows == 0:
            return 0
        if lhs.max_rows is None or rhs.max_rows is None:
            return None
        else:
            return lhs.max_rows * rhs.max_rows

    def applied_common_columns(self, lhs: Relation, rhs: Relation) -> frozenset[ColumnTag]:
        """Compute the actual common columns for a `Join` given its targets.

        Parameters
        ----------
        lhs : `Relation`
            One relation to join.
        rhs : `Relation`
            The other relation to join to ``lhs``.

        Returns
        -------
        common_columns : `frozenset` [ `ColumnTag` ]
            If `min_columns` and `max_columns` are already equal, that set,
            returned without inspecting ``lhs`` or ``rhs``.  Otherwise, the
            key columns (those with ``is_key`` true) that are included in all
            of ``lhs.columns`` and ``rhs.columns`` and `max_columns` (when
            that is not `None`), checked to be a superset of `min_columns`.

        Raises
        ------
        ColumnError
            Raised if the result would not be a superset of `min_columns`.
        """
        if self.max_columns != self.min_columns:
            common_columns = {tag for tag in lhs.columns & rhs.columns if tag.is_key}
            if self.max_columns is not None:
                common_columns &= self.max_columns
            if not (common_columns >= self.min_columns):
                raise ColumnError(
                    f"Common columns {common_columns} for join between {lhs} and {rhs} are not a superset "
                    f"of the minimum columns {self.min_columns}."
                )
            return frozenset(common_columns)
        else:
            return self.min_columns

    def partial(self, fix: Relation, is_lhs: bool = False) -> PartialJoin:
        """Return a `UnaryOperation` that represents this join with one operand
        already provided and held fixed.

        Parameters
        ----------
        fix : `Relation`
            Relation to include in the returned unary operation.
        is_lhs : `bool`, optional
            Whether ``fix`` should be considered the ``lhs`` or ``rhs`` side of
            the join (`Join` side is *usually* irrelevant, but `Engine`
            implementations are permitted to make additional guarantees about
            row order or duplicates based on them).

        Returns
        -------
        partial_join : `PartialJoin`
            Unary operation representing a join to a fixed relation.

        Raises
        ------
        ColumnError
            Raised if `min_columns` includes columns that are not present in
            ``fix``.

        Notes
        -----
        This method and the class it returns are called "partial" in the spirit
        of `functools.partial`: a callable formed by holding some arguments to
        another callable fixed.
        """
        if not (self.min_columns <= fix.columns):
            raise ColumnError(
                f"Missing columns {set(self.min_columns - fix.columns)} for partial join to {fix}."
            )
        return PartialJoin(self, fix, is_lhs)

236 

237 

@final
@dataclasses.dataclass(frozen=True)
class PartialJoin(UnaryOperation):
    """A `Join` with one of its two operands bound in advance, exposed as a
    `UnaryOperation` on the remaining operand.

    Notes
    -----
    The name follows `functools.partial`: this class, like the `Join.partial`
    method that builds it, is a callable obtained by holding some of another
    callable's arguments fixed.

    `PartialJoin` instances never appear in relation trees; the `apply` method
    will return a `BinaryOperationRelation` with a `Join` operation instead of
    a `UnaryOperationRelation` with a `PartialJoin` (or one of the operands, if
    the other is a `join identity relation <Relation.is_join_identity>`).
    """

    binary: Join
    """The underlying join operation to apply (`Join`).
    """

    fixed: Relation
    """The operand that is already bound into this operation (`Relation`).
    """

    fixed_is_lhs: bool
    """Whether `fixed` takes the ``lhs`` position of the join (`bool`);
    otherwise it takes the ``rhs`` position.

    `Join` side is *usually* irrelevant, but `Engine` implementations are
    permitted to make additional guarantees about row order or duplicates based
    on them.
    """

    def _ordered_operands(self, target: Relation) -> tuple[Relation, Relation]:
        # Arrange ``fixed`` and ``target`` as the ``(lhs, rhs)`` pair expected
        # by the binary join, honoring ``fixed_is_lhs``.
        if self.fixed_is_lhs:
            return self.fixed, target
        return target, self.fixed

    @property
    def columns_required(self) -> Set[ColumnTag]:
        # Docstring inherited.
        # Predicate columns that the fixed operand already supplies need not
        # come from the target; the minimal join columns always must.
        from_predicate = set(self.binary.predicate.columns_required).difference(self.fixed.columns)
        return from_predicate.union(self.binary.min_columns)

    @property
    def is_empty_invariant(self) -> bool:
        # Docstring inherited.
        return False

    @property
    def is_count_invariant(self) -> bool:
        # Docstring inherited.
        return False

    def __str__(self) -> str:
        return f"{self.binary!s}[{self.fixed!s}]"

    def _begin_apply(
        self, target: Relation, preferred_engine: Engine | None
    ) -> tuple[UnaryOperation, Engine]:
        # Docstring inherited.
        join = self.binary
        if join.max_columns != join.min_columns:
            # Pin the join's common columns using the actual operands, then
            # restart with the resolved operation so the checks below run
            # against explicit columns.
            resolved_columns = join.applied_common_columns(self.fixed, target)
            resolved_join = dataclasses.replace(
                join, min_columns=resolved_columns, max_columns=resolved_columns
            )
            return dataclasses.replace(self, binary=resolved_join)._begin_apply(target, preferred_engine)
        # With no caller preference, fall back to the fixed operand's engine.
        engine = self.fixed.engine if preferred_engine is None else preferred_engine
        if not self.columns_required <= target.columns:
            raise ColumnError(
                f"Join {self} to relation {target} needs columns "
                f"{set(self.columns_required) - target.columns}."
            )
        return super()._begin_apply(target, engine)

    def _finish_apply(self, target: Relation) -> Relation:
        # Docstring inherited.
        lhs, rhs = self._ordered_operands(target)
        return self.binary.apply(lhs, rhs)

    def applied_columns(self, target: Relation) -> Set[ColumnTag]:
        # Docstring inherited.
        lhs, rhs = self._ordered_operands(target)
        return self.binary.applied_columns(lhs, rhs)

    def applied_min_rows(self, target: Relation) -> int:
        # Docstring inherited.
        lhs, rhs = self._ordered_operands(target)
        return self.binary.applied_min_rows(lhs, rhs)

    def applied_max_rows(self, target: Relation) -> int | None:
        # Docstring inherited.
        lhs, rhs = self._ordered_operands(target)
        return self.binary.applied_max_rows(lhs, rhs)

    def commute(self, current: UnaryOperationRelation) -> UnaryCommutator:
        # Docstring inherited.
        from ._deduplication import Deduplication
        from ._projection import Projection

        inner = current.operation

        if isinstance(inner, Deduplication):
            # Moving a join inside a deduplication is only valid when the
            # fixed relation has unique rows, and there is currently no way
            # to verify that, so decline.
            return UnaryCommutator(
                first=None,
                second=inner,
                done=False,
                messages=("join-deduplication commutation is not supported",),
            )

        if isinstance(inner, Projection):
            # projection(join(target)) matches join(projection(target)) only
            # if the outer projection also keeps the columns the join adds,
            # so a new projection is built from the join's output columns.
            # Common columns are required to be explicit by this point, so
            # the projection cannot alter them.
            return UnaryCommutator(
                first=self,
                second=Projection(frozenset(self.applied_columns(current))),
            )

        # Any other operation: the join can only move inward if the inner
        # target already has every column the join needs ...
        if not self.columns_required <= current.target.columns:
            return UnaryCommutator(
                first=None,
                second=inner,
                done=False,
                messages=(
                    f"{current.target} is missing columns "
                    f"{set(self.columns_required - current.target.columns)}",
                ),
            )
        # ... and the operation is not count-dependent.
        if inner.is_count_dependent:
            return UnaryCommutator(
                first=None,
                second=inner,
                done=False,
                messages=(f"{current.operation} is count-dependent",),
            )
        return UnaryCommutator(first=self, second=inner)