Coverage for python/lsst/ctrl/mpexec/execFixupDataId.py: 22%

41 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-01 09:30 +0000

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

# Export the class defined in this module; ``ExecutionGraphFixup`` is only
# imported here as a base class and is exported by its own module.
__all__ = ["ExecFixupDataId"]

23 

24import contextlib 

25import itertools 

26from collections import defaultdict 

27from collections.abc import Sequence 

28from typing import Any 

29 

30import networkx as nx 

31from lsst.pipe.base import QuantumGraph, QuantumNode 

32 

33from .executionGraphFixup import ExecutionGraphFixup 

34 

35 

class ExecFixupDataId(ExecutionGraphFixup):
    """Implementation of ExecutionGraphFixup for ordering of tasks based
    on DataId values.

    This class is a trivial implementation mostly useful as an example,
    though it can be used to make actual fixup instances by defining
    a method that instantiates it, e.g.::

        # lsst/ap/verify/ci_fixup.py

        from lsst.ctrl.mpexec.execFixupDataId import ExecFixupDataId

        def assoc_fixup():
            return ExecFixupDataId(taskLabel="ap_assoc",
                                   dimensions=("visit", "detector"))

    and then executing pipetask::

        pipetask run --graph-fixup=lsst.ap.verify.ci_fixup.assoc_fixup ...

    This will add new dependencies between quanta executed by the task with
    label "ap_assoc". Quanta with higher visit number will depend on quanta
    with lower visit number and their execution will wait until lower visit
    number finishes.

    Quanta are grouped by the tuple of their dimension values, and the groups
    are sorted by comparing those tuples lexicographically. Every quantum in
    a group then depends on every quantum in the preceding group. Quanta that
    share the same tuple get no ordering relative to each other.

    Parameters
    ----------
    taskLabel : `str`
        The label of the task for which to add dependencies.
    dimensions : `str` or sequence [`str`]
        One or more dimension names, quanta execution will be ordered
        according to values of these dimensions.
    reverse : `bool`, optional
        If `False` (default) then quanta with higher values of dimensions
        will be executed after quanta with lower values, otherwise the order
        is reversed.
    """

    def __init__(self, taskLabel: str, dimensions: str | Sequence[str], reverse: bool = False):
        self.taskLabel = taskLabel
        # Normalize to a tuple so that a single dimension name is treated as
        # one dimension rather than iterated character by character.
        self.dimensions: tuple[str, ...] = (
            (dimensions,) if isinstance(dimensions, str) else tuple(dimensions)
        )
        self.reverse = reverse

    def _key(self, qnode: QuantumNode) -> tuple[Any, ...]:
        """Produce comparison key for quantum data.

        Parameters
        ----------
        qnode : `QuantumNode`
            An individual node in a `~lsst.pipe.base.QuantumGraph`.

        Returns
        -------
        key : `tuple`
            Values of ``self.dimensions`` taken from the quantum's data ID,
            in the same order as ``self.dimensions``.

        Raises
        ------
        ValueError
            Raised if the quantum has no data ID.
        """
        dataId = qnode.quantum.dataId
        if dataId is None:
            # Explicit check instead of ``assert`` so that the failure is
            # still reported clearly when Python runs with ``-O``.
            raise ValueError("Quantum DataId cannot be None")
        return tuple(dataId[dim] for dim in self.dimensions)

    def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Add ordering dependencies between quanta of ``self.taskLabel``.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            Quantum graph to update. Edges are removed from and added to the
            object returned by ``graph.graph``, and ``graph`` itself is
            returned. This relies on that attribute not being a copy.

        Returns
        -------
        graph : `~lsst.pipe.base.QuantumGraph`
            The same ``graph`` instance that was passed in.

        Raises
        ------
        ValueError
            Raised if no task with label ``self.taskLabel`` exists in
            ``graph``, or if one of its quanta has no data ID.
        """
        taskDef = graph.findTaskDefByLabel(self.taskLabel)
        if taskDef is None:
            raise ValueError(f"Cannot find task with label {self.taskLabel}")

        # Group the task's quanta by their dimension-value key.
        keyQuanta: defaultdict[tuple[Any, ...], list[QuantumNode]] = defaultdict(list)
        for qnode in graph.getNodesForTask(taskDef):
            keyQuanta[self._key(qnode)].append(qnode)
        keys = sorted(keyQuanta, reverse=self.reverse)
        networkGraph = graph.graph

        # Chain consecutive groups: each quantum depends on every quantum of
        # the group that precedes it in sort order.
        for prevKey, key in itertools.pairwise(keys):
            for prevNode in keyQuanta[prevKey]:
                for node in keyQuanta[key]:
                    # Remove any existing edges between the two nodes, but
                    # don't fail if there are none. Both directions need to
                    # be tried because in a directed graph the order matters.
                    # NOTE(review): removing an existing (node, prevNode) edge
                    # drops a dependency that was already present; presumably
                    # this is accepted to avoid creating a cycle. Confirm.
                    for edge in ((node, prevNode), (prevNode, node)):
                        with contextlib.suppress(nx.NetworkXException):
                            networkGraph.remove_edge(*edge)

                    networkGraph.add_edge(prevNode, node)
        return graph