Coverage for python / lsst / pipe / base / quantum_graph / aggregator / _supervisor.py: 29%

99 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 08:49 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("aggregate_graph",) 

31 

32import dataclasses 

33import itertools 

34import uuid 

35 

36import astropy.units as u 

37import networkx 

38 

39from lsst.utils.logging import getLogger 

40from lsst.utils.usage import get_peak_mem_usage 

41 

42from ...graph_walker import GraphWalker 

43from ...pipeline_graph import TaskImportMode 

44from .._predicted import PredictedQuantumGraphComponents, PredictedQuantumGraphReader 

45from .._provenance import ProvenanceQuantumScanData, ProvenanceQuantumScanStatus 

46from ._communicators import ( 

47 IngesterCommunicator, 

48 ScannerCommunicator, 

49 SupervisorCommunicator, 

50 WriterCommunicator, 

51) 

52from ._config import AggregatorConfig 

53from ._ingester import Ingester 

54from ._scanner import Scanner 

55from ._structs import ScanReport 

56from ._workers import SpawnWorkerFactory, ThreadWorkerFactory 

57from ._writer import Writer 

58 

59 

60@dataclasses.dataclass 

61class Supervisor: 

62 """The main process/thread for the provenance aggregator.""" 

63 

64 predicted_path: str 

65 """Path to the predicted quantum graph.""" 

66 

67 comms: SupervisorCommunicator 

68 """Communicator object for the supervisor.""" 

69 

70 predicted: PredictedQuantumGraphComponents = dataclasses.field(init=False) 

71 """Components of the predicted quantum graph.""" 

72 

73 walker: GraphWalker[uuid.UUID] = dataclasses.field(init=False) 

74 """Iterator that traverses the quantum graph.""" 

75 

76 n_abandoned: int = 0 

77 """Number of quanta we abandoned because they did not complete in time and 

78 we could not assume they had failed. 

79 """ 

80 

81 def __post_init__(self) -> None: 

82 self.comms.progress.log.info("Reading predicted quantum graph.") 

83 with PredictedQuantumGraphReader.open( 

84 self.predicted_path, import_mode=TaskImportMode.DO_NOT_IMPORT 

85 ) as reader: 

86 reader.read_thin_graph() 

87 reader.read_init_quanta() 

88 self.predicted = reader.components 

89 self.comms.progress.log.info("Analyzing predicted graph.") 

90 xgraph = networkx.DiGraph(self.predicted.thin_graph.edges) 

91 # Make sure all quanta are in the graph, even if they don't have any 

92 # quantum-only edges. 

93 for thin_quantum in itertools.chain.from_iterable(self.predicted.thin_graph.quanta.values()): 

94 xgraph.add_node(thin_quantum.quantum_id) 

95 # Add init quanta as nodes without edges, because the scanner should 

96 # only be run after init outputs are all written and hence we don't 

97 # care when we process them. 

98 for init_quantum in self.predicted.init_quanta.root: 

99 xgraph.add_node(init_quantum.quantum_id) 

100 self.walker = GraphWalker(xgraph) 

101 

102 def loop(self) -> None: 

103 """Scan the outputs of the quantum graph to gather provenance and 

104 ingest outputs. 

105 """ 

106 n_quanta = self.predicted.header.n_quanta + len(self.predicted.init_quanta.root) 

107 self.comms.progress.scans.total = n_quanta 

108 self.comms.progress.writes.total = n_quanta 

109 self.comms.progress.quantum_ingests.total = n_quanta 

110 ready_set: set[uuid.UUID] = set() 

111 for ready_quanta in self.walker: 

112 self.comms.log.debug("Sending %d new quanta to scan queue.", len(ready_quanta)) 

113 ready_set.update(ready_quanta) 

114 while ready_set: 

115 self.comms.request_scan(ready_set.pop()) 

116 for scan_return in self.comms.poll(): 

117 self.handle_report(scan_return) 

118 if self.comms.config.incomplete: 

119 quantum_or_quanta = "quanta" if self.n_abandoned != 1 else "quantum" 

120 self.comms.progress.log.info( 

121 "%d %s incomplete/failed abandoned; re-run with incomplete=False to finish.", 

122 self.n_abandoned, 

123 quantum_or_quanta, 

124 ) 

125 self.comms.progress.log.info( 

126 "Scanning complete after %0.1fs; waiting for workers to finish.", 

127 self.comms.progress.elapsed_time, 

128 ) 

129 

130 def handle_report(self, scan_report: ScanReport) -> None: 

131 """Handle a report from a scanner. 

132 

133 Parameters 

134 ---------- 

135 scan_report : `ScanReport` 

136 Information about the scan. 

137 """ 

138 match scan_report.status: 

139 case ProvenanceQuantumScanStatus.SUCCESSFUL | ProvenanceQuantumScanStatus.INIT: 

140 self.comms.log.debug("Scan complete for %s: quantum succeeded.", scan_report.quantum_id) 

141 self.walker.finish(scan_report.quantum_id) 

142 case ProvenanceQuantumScanStatus.FAILED: 

143 self.comms.log.debug("Scan complete for %s: quantum failed.", scan_report.quantum_id) 

144 blocked_quanta = self.walker.fail(scan_report.quantum_id) 

145 for blocked_quantum_id in blocked_quanta: 

146 if self.comms.config.is_writing_provenance: 

147 self.comms.request_write( 

148 ProvenanceQuantumScanData( 

149 blocked_quantum_id, status=ProvenanceQuantumScanStatus.BLOCKED 

150 ) 

151 ) 

152 self.comms.progress.scans.update(1) 

153 self.comms.progress.quantum_ingests.update(len(blocked_quanta)) 

154 case ProvenanceQuantumScanStatus.ABANDONED: 

155 self.comms.log.debug("Abandoning scan for %s: quantum has not succeeded (yet).") 

156 self.walker.fail(scan_report.quantum_id) 

157 self.n_abandoned += 1 

158 case unexpected: 

159 raise AssertionError( 

160 f"Unexpected status {unexpected!r} in scanner loop for {scan_report.quantum_id}." 

161 ) 

162 self.comms.progress.scans.update(1) 

163 

164 

def aggregate_graph(predicted_path: str, butler_path: str, config: AggregatorConfig) -> None:
    """Run the graph aggregator tool.

    Parameters
    ----------
    predicted_path : `str`
        Path to the predicted quantum graph.
    butler_path : `str`
        Path or alias to the central butler repository.
    config : `AggregatorConfig`
        Configuration for the aggregator.
    """
    log = getLogger("lsst.pipe.base.quantum_graph.aggregator")
    # A single process means we can run everything in threads; otherwise we
    # spawn separate worker processes.
    if config.n_processes == 1:
        factory = ThreadWorkerFactory()
    else:
        factory = SpawnWorkerFactory()
    with SupervisorCommunicator(log, config.n_processes, factory, config) as comms:
        comms.progress.log.verbose("Starting workers.")
        # Start the (optional) provenance writer, then the scanners, then the
        # ingester, registering each worker under its communicator's name.
        if config.is_writing_provenance:
            w_comms = WriterCommunicator(comms)
            comms.workers[w_comms.name] = factory.make_worker(
                target=Writer.run,
                args=(predicted_path, w_comms),
                name=w_comms.name,
            )
        for index in range(config.n_processes):
            s_comms = ScannerCommunicator(comms, index)
            comms.workers[s_comms.name] = factory.make_worker(
                target=Scanner.run,
                args=(predicted_path, butler_path, s_comms),
                name=s_comms.name,
            )
        i_comms = IngesterCommunicator(comms)
        comms.workers[i_comms.name] = factory.make_worker(
            target=Ingester.run,
            args=(predicted_path, butler_path, i_comms),
            name=i_comms.name,
        )
        Supervisor(predicted_path, comms).loop()
    # We can't get memory usage for children until they've joined.
    parent_mem, child_mem = get_peak_mem_usage()
    # This is actually an upper bound on the peak (since the peaks could be
    # at different times), but since we expect memory usage to be more smooth
    # than spiky that's fine.
    total_mem: u.Quantity = parent_mem + child_mem
    log.info(
        "All aggregation tasks complete after %0.1fs; peak memory usage ≤ %0.1f MB.",
        comms.progress.elapsed_time,
        total_mem.to(u.MB).value,
    )