Coverage for python/lsst/ctrl/mpexec/execFixupDataId.py: 19%

41 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-14 19:56 +0000

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

# Public API of this module: the DataId-ordering fixup class it defines
# (the base class is only imported here, not defined).
__all__ = ["ExecFixupDataId"]

23 

24from collections import defaultdict 

25from collections.abc import Sequence 

26from typing import Any 

27 

28import networkx as nx 

29from lsst.pipe.base import QuantumGraph, QuantumNode 

30 

31from .executionGraphFixup import ExecutionGraphFixup 

32 

33 

class ExecFixupDataId(ExecutionGraphFixup):
    """Implementation of ExecutionGraphFixup for ordering of tasks based
    on DataId values.

    This class is a trivial implementation mostly useful as an example,
    though it can be used to make actual fixup instances by defining
    a method that instantiates it, e.g.::

        # lsst/ap/verify/ci_fixup.py

        from lsst.ctrl.mpexec.execFixupDataId import ExecFixupDataId

        def assoc_fixup():
            return ExecFixupDataId(taskLabel="ap_assoc",
                                   dimensions=("visit", "detector"))


    and then executing pipetask::

        pipetask run --graph-fixup=lsst.ap.verify.ci_fixup.assoc_fixup ...

    This will add new dependencies between quanta executed by the task with
    label "ap_assoc". Quanta with higher visit number will depend on quanta
    with lower visit number and their execution will wait until lower visit
    number finishes.

    Parameters
    ----------
    taskLabel : `str`
        The label of the task for which to add dependencies.
    dimensions : `str` or sequence [`str`]
        One or more dimension names; quanta execution will be ordered
        according to the values of these dimensions.
    reverse : `bool`, optional
        If `False` (default) then quanta with higher values of dimensions
        will be executed after quanta with lower values, otherwise the order
        is reversed.
    """

    def __init__(self, taskLabel: str, dimensions: str | Sequence[str], reverse: bool = False):
        self.taskLabel = taskLabel
        # Normalize to a tuple so that a single dimension name and a sequence
        # of names are handled uniformly by `_key`.
        self.dimensions: tuple[str, ...] = (
            (dimensions,) if isinstance(dimensions, str) else tuple(dimensions)
        )
        self.reverse = reverse

    def _key(self, qnode: QuantumNode) -> tuple[Any, ...]:
        """Produce comparison key for quantum data.

        Parameters
        ----------
        qnode : `QuantumNode`
            An individual node in a `~lsst.pipe.base.QuantumGraph`.

        Returns
        -------
        key : `tuple`
            Values of ``self.dimensions`` taken from the quantum's data ID,
            in the order the dimensions were given.

        Raises
        ------
        ValueError
            Raised if the quantum has no data ID.
        """
        dataId = qnode.quantum.dataId
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        if dataId is None:
            raise ValueError("Quantum DataId cannot be None")
        return tuple(dataId[dim] for dim in self.dimensions)

    def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Add ordering dependencies between quanta of the configured task.

        Quanta of the task labeled ``self.taskLabel`` are grouped by their
        key (see `_key`). The distinct keys are sorted (descending when
        ``self.reverse`` is true), and every quantum in one group is made a
        dependency of every quantum in the next group.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            Quantum graph to update. Edges are added to (and reverse edges
            removed from) the networkx graph exposed as ``graph.graph``.

        Returns
        -------
        graph : `~lsst.pipe.base.QuantumGraph`
            The same ``graph`` object that was passed in.

        Raises
        ------
        ValueError
            Raised if ``graph`` has no task with label ``self.taskLabel``.
        """
        taskDef = graph.findTaskDefByLabel(self.taskLabel)
        if taskDef is None:
            raise ValueError(f"Cannot find task with label {self.taskLabel}")

        # Group quanta by key; several quanta may share one key when the task
        # has dimensions beyond those used for ordering.
        keyQuanta: defaultdict[tuple[Any, ...], list[QuantumNode]] = defaultdict(list)
        for qnode in graph.getNodesForTask(taskDef):
            keyQuanta[self._key(qnode)].append(qnode)
        keys = sorted(keyQuanta, reverse=self.reverse)
        networkGraph = graph.graph

        for prev_key, key in zip(keys, keys[1:]):
            for prev_node in keyQuanta[prev_key]:
                for node in keyQuanta[key]:
                    # Remove an existing edge in the opposite direction so the
                    # new ordering edge cannot introduce a cycle; in a
                    # directed graph, edge direction matters.
                    if networkGraph.has_edge(node, prev_node):
                        networkGraph.remove_edge(node, prev_node)
                    # Adding an edge that already exists only updates its
                    # data (none is passed here), so a pre-existing forward
                    # edge is kept as-is instead of being removed and re-added.
                    networkGraph.add_edge(prev_node, node)
        return graph