Coverage for python/lsst/daf/relation/iteration/_engine.py: 13%

148 statements

(coverage.py v7.2.7, created at 2023-07-21 09:39 +0000)

1# This file is part of daf_relation. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("Engine",) 

25 

26import dataclasses 

27import itertools 

28from collections.abc import Callable, Container, Mapping, Sequence, Set 

29from operator import attrgetter, itemgetter 

30from typing import Any 

31 

32from .._columns import ( 

33 ColumnContainer, 

34 ColumnExpression, 

35 ColumnExpressionSequence, 

36 ColumnFunction, 

37 ColumnInContainer, 

38 ColumnLiteral, 

39 ColumnRangeLiteral, 

40 ColumnReference, 

41 ColumnTag, 

42 LogicalAnd, 

43 LogicalNot, 

44 LogicalOr, 

45 Predicate, 

46 PredicateFunction, 

47 PredicateLiteral, 

48 PredicateReference, 

49) 

50from .._engine import Engine as BaseEngine 

51from .._engine import GenericConcreteEngine 

52from .._exceptions import EngineError 

53from .._leaf_relation import LeafRelation 

54from .._marker_relation import MarkerRelation 

55from .._materialization import Materialization 

56from .._operation_relations import BinaryOperationRelation, UnaryOperationRelation 

57from .._operations import Calculation, Chain, Deduplication, Join, Projection, Selection, Slice, Sort 

58from .._relation import Relation 

59from .._transfer import Transfer 

60from .._unary_operation import UnaryOperation 

61from ._row_iterable import ( 

62 CalculationRowIterable, 

63 ChainRowIterable, 

64 MaterializedRowIterable, 

65 ProjectionRowIterable, 

66 RowIterable, 

67 RowMapping, 

68 RowSequence, 

69 SelectionRowIterable, 

70) 

71 

72 

73@dataclasses.dataclass(repr=False, eq=False, kw_only=True) 

74class Engine(GenericConcreteEngine[Callable[..., Any]]): 

75 """A concrete engine that treats relations as iterables with 

76 `~collections.abc.Mapping` rows. 

77 

78 See the `.iteration` module documentation for details. 

79 """ 

80 

81 name: str = "iteration" 

82 

83 def __repr__(self) -> str: 

84 return f"lsst.daf.relation.iteration.Engine({self.name!r})@{id(self):0x}" 

85 

86 def make_leaf( 

87 self, 

88 columns: Set[ColumnTag], 

89 payload: MaterializedRowIterable, 

90 *, 

91 name: str = "", 

92 messages: Sequence[str] = (), 

93 name_prefix: str = "leaf", 

94 parameters: Any = None, 

95 ) -> LeafRelation: 

96 """Create a nontrivial leaf relation in this engine. 

97 

98 This is a convenience method that simply forwards all arguments to 

99 the `.LeafRelation` constructor; see that class for details. 

100 """ 

101 return LeafRelation( 

102 self, 

103 frozenset(columns), 

104 payload, 

105 min_rows=len(payload), 

106 max_rows=len(payload), 

107 messages=messages, 

108 name=name, 

109 name_prefix=name_prefix, 

110 parameters=parameters, 

111 ) 

112 

113 def get_join_identity_payload(self) -> RowIterable: 

114 # Docstring inherited. 

115 return RowMapping((), {(): {}}) 

116 

117 def get_doomed_payload(self, columns: Set[ColumnTag]) -> RowIterable: 

118 # Docstring inherited. 

119 return RowMapping((), {}) 

120 

121 def backtrack_unary( 

122 self, operation: UnaryOperation, tree: Relation, preferred: BaseEngine 

123 ) -> tuple[Relation, bool, tuple[str, ...]]: 

124 # Docstring inherited. 

125 if tree.is_locked: 

126 return tree, False, (f"{tree} is locked",) 

127 match tree: 

128 case UnaryOperationRelation(target=target): 

129 commutator = operation.commute(tree) 

130 if commutator.first is None: 

131 return tree, commutator.done, commutator.messages 

132 else: 

133 upstream, done, messages = self.backtrack_unary(commutator.first, target, preferred) 

134 if upstream is not target: 

135 result = commutator.second._finish_apply(upstream) 

136 else: 

137 result = tree 

138 return ( 

139 result, 

140 done and commutator.done, 

141 commutator.messages + messages, 

142 ) 

143 case BinaryOperationRelation(): 

144 return tree, False, ("backtracking through binary operations is not implemented",) 

145 case Transfer(target=target) as transfer: 

146 if target.engine == preferred: 

147 return transfer.reapply(operation.apply(target)), True, () 

148 else: 

149 upstream, done, messages = target.engine.backtrack_unary(operation, target, preferred) 

150 return (transfer.reapply(upstream), done, messages) 

151 raise NotImplementedError(f"Unsupported relation type {tree} for engine {self}.") 

152 

153 def execute(self, relation: Relation) -> RowIterable: 

154 """Execute a native iteration relation, returning a Python iterable. 

155 

156 Parameters 

157 ---------- 

158 relation : `.Relation` 

159 Relation to execute. 

160 

161 Returns 

162 ------- 

163 rows : `RowIterable` 

164 Iterable over rows, with each row a mapping keyed by `.ColumnTag`. 

165 

166 Notes 

167 ----- 

168 This method does not in general iterate over the relation's rows; while 

169 some operations like `Sort` and `Deduplication` do require processing 

170 all rows up front (which will happen during a call to `execute`), most 

171 return lazy iterables that do little or nothing until actually iterated 

172 over. 

173 

174 This method requires all relations in the tree to have the same engine 

175 (``self``). Use the `.Processor` class to handle trees with multiple 

176 engines. 

177 """ 

178 if relation.engine != self: 

179 raise EngineError( 

180 f"Engine {self!r} cannot operate on relation {relation} with engine {relation.engine!r}. " 

181 "Use lsst.daf.relation.process to evaluate transfers first." 

182 ) 

183 if relation.max_rows == 0: 

184 return RowSequence([]) 

185 if relation.is_join_identity: 

186 return RowSequence([{}]) 

187 if (result := relation.payload) is not None: 

188 return result 

189 match relation: 

190 case UnaryOperationRelation(operation=operation, target=target): 

191 target_rows = self.execute(target) 

192 match operation: 

193 case Calculation(tag=tag, expression=expression): 

194 return CalculationRowIterable( 

195 target_rows, tag, self.convert_column_expression(expression) 

196 ) 

197 case Deduplication(): 

198 unique_key = tuple(tag for tag in relation.columns if tag.is_key) 

199 return target_rows.to_mapping(unique_key) 

200 case Projection(columns=columns): 

201 return ProjectionRowIterable(target_rows, columns) 

202 case Selection(predicate=predicate): 

203 return SelectionRowIterable(target_rows, self.convert_predicate(predicate)) 

204 case Slice(start=start, stop=stop): 

205 return target_rows.sliced(start, stop) 

206 case Sort(terms=terms): 

207 rows_list = list(target_rows) 

208 # Python's built-in sorting methods are stable, but 

209 # they don't provide a way to sort some keys in 

210 # ascending order and others in descending order. So 

211 # we split the sequence of order-by terms into groups 

212 # of consecutive terms with the same 

213 # ascending/descending state. At the same time, we 

214 # use the visitor to transform each expression into a 

215 # callable we can apply to each row. 

216 grouped_by_ascending = [ 

217 (ascending, [self.convert_column_expression(t.expression) for t in terms]) 

218 for ascending, terms in itertools.groupby(terms, key=attrgetter("ascending")) 

219 ] 

220 # Now we can sort the full list of rows once per group, 

221 # in reverse so sort terms at the start of the order-by 

222 # sequence "win". 

223 for ascending, callables in grouped_by_ascending[::-1]: 

224 rows_list.sort( 

225 key=lambda row: tuple(c(row) for c in callables), 

226 reverse=not ascending, 

227 ) 

228 return RowSequence(rows_list) 

229 case _: 

230 return self.apply_custom_unary_operation(operation, target) 

231 case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs): 

232 match operation: 

233 case Chain(): 

234 return ChainRowIterable([self.execute(lhs), self.execute(rhs)]) 

235 case Join(): 

236 raise EngineError("Joins are not supported by the iteration engine.") 

237 raise EngineError(f"Custom binary operation {operation} is not supported.") 

238 case Materialization(target=target): 

239 result = self.execute(target).materialized() 

240 relation.attach_payload(result) 

241 return result 

242 case Transfer(destination=destination, target=target): 

243 if isinstance(target.engine, Engine): 

244 # This is a transfer from another iteration engine 

245 # (maybe a subclass). We can handle that without 

246 # requiring a Processor, and that's useful for at 

247 # least unit testing (though maybe not anything 

248 # else; it's not clear why you'd have a transfer 

249 # between iteration engines otherwise). 

250 return target.engine.execute(target) 

251 raise EngineError( 

252 f"Engine {self!r} cannot handle transfer from " 

253 f"{target.engine!r} to {destination!r}; " 

254 "use `lsst.daf.relation.Processor` first to handle this operation." 

255 ) 

256 case MarkerRelation(target=target): 

257 return self.execute(target) 

258 raise AssertionError("matches should be exhaustive and all branches should return") 

259 

260 def convert_column_expression( 

261 self, expression: ColumnExpression 

262 ) -> Callable[[Mapping[ColumnTag, Any]], Any]: 

263 """Convert a `.ColumnExpression` to a Python callable. 

264 

265 Parameters 

266 ---------- 

267 expression : `.ColumnExpression` 

268 Expression to convert. 

269 

270 Returns 

271 ------- 

272 callable 

273 Callable that takes a single `~collections.abc.Mapping` argument 

274 (with `.ColumnTag` keys and regular Python values, representing a 

275 row in a relation), returning the evaluated expression as another 

276 regular Python value. 

277 """ 

278 match expression: 

279 case ColumnLiteral(value=value): 

280 return lambda row: value 

281 case ColumnReference(tag=tag): 

282 return itemgetter(tag) 

283 case ColumnFunction(name=name, args=args): 

284 function = self.get_function(name) 

285 if function is not None: 

286 arg_callables = [self.convert_column_expression(arg) for arg in args] 

287 # MyPy doesn't see 'function' as not-None for some reason. 

288 return lambda row: function(*[c(row) for c in arg_callables]) # type: ignore 

289 first, *rest = (self.convert_column_expression(arg) for arg in args) 

290 return lambda row: getattr(first(row), name)(*[r(row) for r in rest]) 

291 raise AssertionError("matches should be exhaustive and all branches should return") 

292 

293 def convert_column_container( 

294 self, expression: ColumnContainer 

295 ) -> Callable[[Mapping[ColumnTag, Any]], Container]: 

296 """Convert a `.ColumnContainer` to a Python callable. 

297 

298 Parameters 

299 ---------- 

300 expression : `.ColumnContainer` 

301 Expression to convert. 

302 

303 Returns 

304 ------- 

305 callable 

306 Callable that takes a single `~collections.abc.Mapping` argument 

307 (with `.ColumnTag` keys and regular Python values, representing a 

308 row in a relation), returning the evaluated expression as 

309 `collections.abc.Container` instance. 

310 """ 

311 match expression: 

312 case ColumnRangeLiteral(value=value): 

313 return lambda row: value 

314 case ColumnExpressionSequence(items=items): 

315 item_callables = [self.convert_column_expression(item) for item in items] 

316 return lambda row: {c(row) for c in item_callables} 

317 raise AssertionError("matches should be exhaustive and all branches should return") 

318 

319 def convert_predicate(self, predicate: Predicate) -> Callable[[Mapping[ColumnTag, Any]], bool]: 

320 """Convert a `Predicate` to a Python callable. 

321 

322 Parameters 

323 ---------- 

324 predicate : `Predicate` 

325 Expression to convert. 

326 

327 Returns 

328 ------- 

329 callable 

330 Callable that takes a single `~collections.abc.Mapping` argument 

331 (with `.ColumnTag` keys and regular Python values, representing a 

332 row in a relation), returning the evaluated expression as a `bool`. 

333 """ 

334 match predicate: 

335 case PredicateFunction(name=name, args=args): 

336 function = self.get_function(name) 

337 if function is not None: 

338 arg_callables = [self.convert_column_expression(arg) for arg in args] 

339 # MyPy doesn't see 'function' as not-None in the capture 

340 # for some reason. 

341 return lambda row: function(*[c(row) for c in arg_callables]) # type: ignore 

342 first, *rest = (self.convert_column_expression(arg) for arg in args) 

343 return lambda row: getattr(first(row), name)(*[r(row) for r in rest]) 

344 case LogicalAnd(operands=operands): 

345 operand_callables = [self.convert_predicate(arg) for arg in operands] 

346 return lambda row: all(c(row) for c in operand_callables) 

347 case LogicalOr(operands=operands): 

348 operand_callables = [self.convert_predicate(arg) for arg in operands] 

349 return lambda row: any(c(row) for c in operand_callables) 

350 case LogicalNot(operand=operand): 

351 target_callable = self.convert_predicate(operand) 

352 return lambda row: not target_callable(row) 

353 case PredicateReference(tag=tag): 

354 return itemgetter(tag) 

355 case PredicateLiteral(value=value): 

356 return lambda row: value 

357 case ColumnInContainer(item=item, container=container): 

358 item_callable = self.convert_column_expression(item) 

359 container_callable = self.convert_column_container(container) 

360 return lambda row: item_callable(row) in container_callable(row) 

361 raise AssertionError("matches should be exhaustive and all branches should return") 

362 

363 def apply_custom_unary_operation(self, operation: UnaryOperation, target: Relation) -> RowIterable: 

364 """Convert a custom `.UnaryOperation` to a `RowIterable`. 

365 

366 This method must be implemented in a subclass engine in order to 

367 support any custom `.UnaryOperation`. 

368 

369 Parameters 

370 ---------- 

371 operation : `.UnaryOperation` 

372 Operation to apply. Guaranteed to be a `.Marker`, `.Reordering`, 

373 or `.RowFilter` subclass. 

374 target : `.Relation` 

375 Target of the unary operation. Typically this will be passed to 

376 `execute` and the result used to construct a new `RowIterable`. 

377 

378 Returns 

379 ------- 

380 rows : `RowIterable` 

381 Iterable over rows, with each row a mapping keyed by `.ColumnTag`. 

382 """ 

383 raise EngineError(f"Custom operation {operation} not supported by engine {self}.")