Coverage for python / lsst / pipe / base / quantum_graph / aggregator / _supervisor.py: 29%
99 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:47 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("aggregate_graph",)
32import dataclasses
33import itertools
34import uuid
36import astropy.units as u
37import networkx
39from lsst.utils.logging import getLogger
40from lsst.utils.usage import get_peak_mem_usage
42from ...graph_walker import GraphWalker
43from ...pipeline_graph import TaskImportMode
44from .._predicted import PredictedQuantumGraphComponents, PredictedQuantumGraphReader
45from .._provenance import ProvenanceQuantumScanData, ProvenanceQuantumScanStatus
46from ._communicators import (
47 IngesterCommunicator,
48 ScannerCommunicator,
49 SupervisorCommunicator,
50 WriterCommunicator,
51)
52from ._config import AggregatorConfig
53from ._ingester import Ingester
54from ._scanner import Scanner
55from ._structs import ScanReport
56from ._workers import SpawnWorkerFactory, ThreadWorkerFactory
57from ._writer import Writer
60@dataclasses.dataclass
61class Supervisor:
62 """The main process/thread for the provenance aggregator."""
64 predicted_path: str
65 """Path to the predicted quantum graph."""
67 comms: SupervisorCommunicator
68 """Communicator object for the supervisor."""
70 predicted: PredictedQuantumGraphComponents = dataclasses.field(init=False)
71 """Components of the predicted quantum graph."""
73 walker: GraphWalker[uuid.UUID] = dataclasses.field(init=False)
74 """Iterator that traverses the quantum graph."""
76 n_abandoned: int = 0
77 """Number of quanta we abandoned because they did not complete in time and
78 we could not assume they had failed.
79 """
81 def __post_init__(self) -> None:
82 self.comms.progress.log.info("Reading predicted quantum graph.")
83 with PredictedQuantumGraphReader.open(
84 self.predicted_path, import_mode=TaskImportMode.DO_NOT_IMPORT
85 ) as reader:
86 reader.read_thin_graph()
87 reader.read_init_quanta()
88 self.predicted = reader.components
89 self.comms.progress.log.info("Analyzing predicted graph.")
90 xgraph = networkx.DiGraph(self.predicted.thin_graph.edges)
91 # Make sure all quanta are in the graph, even if they don't have any
92 # quantum-only edges.
93 for thin_quantum in itertools.chain.from_iterable(self.predicted.thin_graph.quanta.values()):
94 xgraph.add_node(thin_quantum.quantum_id)
95 # Add init quanta as nodes without edges, because the scanner should
96 # only be run after init outputs are all written and hence we don't
97 # care when we process them.
98 for init_quantum in self.predicted.init_quanta.root:
99 xgraph.add_node(init_quantum.quantum_id)
100 self.walker = GraphWalker(xgraph)
102 def loop(self) -> None:
103 """Scan the outputs of the quantum graph to gather provenance and
104 ingest outputs.
105 """
106 n_quanta = self.predicted.header.n_quanta + len(self.predicted.init_quanta.root)
107 self.comms.progress.scans.total = n_quanta
108 self.comms.progress.writes.total = n_quanta
109 self.comms.progress.quantum_ingests.total = n_quanta
110 ready_set: set[uuid.UUID] = set()
111 for ready_quanta in self.walker:
112 self.comms.log.debug("Sending %d new quanta to scan queue.", len(ready_quanta))
113 ready_set.update(ready_quanta)
114 while ready_set:
115 self.comms.request_scan(ready_set.pop())
116 for scan_return in self.comms.poll():
117 self.handle_report(scan_return)
118 if self.comms.config.incomplete:
119 quantum_or_quanta = "quanta" if self.n_abandoned != 1 else "quantum"
120 self.comms.progress.log.info(
121 "%d %s incomplete/failed abandoned; re-run with incomplete=False to finish.",
122 self.n_abandoned,
123 quantum_or_quanta,
124 )
125 self.comms.progress.log.info(
126 "Scanning complete after %0.1fs; waiting for workers to finish.",
127 self.comms.progress.elapsed_time,
128 )
130 def handle_report(self, scan_report: ScanReport) -> None:
131 """Handle a report from a scanner.
133 Parameters
134 ----------
135 scan_report : `ScanReport`
136 Information about the scan.
137 """
138 match scan_report.status:
139 case ProvenanceQuantumScanStatus.SUCCESSFUL | ProvenanceQuantumScanStatus.INIT:
140 self.comms.log.debug("Scan complete for %s: quantum succeeded.", scan_report.quantum_id)
141 self.walker.finish(scan_report.quantum_id)
142 case ProvenanceQuantumScanStatus.FAILED:
143 self.comms.log.debug("Scan complete for %s: quantum failed.", scan_report.quantum_id)
144 blocked_quanta = self.walker.fail(scan_report.quantum_id)
145 for blocked_quantum_id in blocked_quanta:
146 if self.comms.config.is_writing_provenance:
147 self.comms.request_write(
148 ProvenanceQuantumScanData(
149 blocked_quantum_id, status=ProvenanceQuantumScanStatus.BLOCKED
150 )
151 )
152 self.comms.progress.scans.update(1)
153 self.comms.progress.quantum_ingests.update(len(blocked_quanta))
154 case ProvenanceQuantumScanStatus.ABANDONED:
155 self.comms.log.debug("Abandoning scan for %s: quantum has not succeeded (yet).")
156 self.walker.fail(scan_report.quantum_id)
157 self.n_abandoned += 1
158 case unexpected:
159 raise AssertionError(
160 f"Unexpected status {unexpected!r} in scanner loop for {scan_report.quantum_id}."
161 )
162 self.comms.progress.scans.update(1)
def aggregate_graph(predicted_path: str, butler_path: str, config: AggregatorConfig) -> None:
    """Run the graph aggregator tool.

    Parameters
    ----------
    predicted_path : `str`
        Path to the predicted quantum graph.
    butler_path : `str`
        Path or alias to the central butler repository.
    config : `AggregatorConfig`
        Configuration for the aggregator.
    """
    log = getLogger("lsst.pipe.base.quantum_graph.aggregator")
    # With a single process everything runs in threads; otherwise each worker
    # is spawned as a separate process.
    factory = SpawnWorkerFactory() if config.n_processes != 1 else ThreadWorkerFactory()
    with SupervisorCommunicator(log, config.n_processes, factory, config) as comms:
        comms.progress.log.verbose("Starting workers.")
        # Start the optional provenance writer, then the scanners, and
        # finally the ingester, registering each under its communicator name.
        if config.is_writing_provenance:
            w_comms = WriterCommunicator(comms)
            comms.workers[w_comms.name] = factory.make_worker(
                target=Writer.run, args=(predicted_path, w_comms), name=w_comms.name
            )
        for index in range(config.n_processes):
            s_comms = ScannerCommunicator(comms, index)
            comms.workers[s_comms.name] = factory.make_worker(
                target=Scanner.run, args=(predicted_path, butler_path, s_comms), name=s_comms.name
            )
        i_comms = IngesterCommunicator(comms)
        comms.workers[i_comms.name] = factory.make_worker(
            target=Ingester.run, args=(predicted_path, butler_path, i_comms), name=i_comms.name
        )
        # Drive the scan/write/ingest loop from this process.
        Supervisor(predicted_path, comms).loop()
    # We can't get memory usage for children until they've joined.
    parent_mem, child_mem = get_peak_mem_usage()
    # This is actually an upper bound on the peak (since the peaks could be
    # at different times), but since we expect memory usage to be more smooth
    # than spiky that's fine.
    total_mem: u.Quantity = parent_mem + child_mem
    log.info(
        "All aggregation tasks complete after %0.1fs; peak memory usage ≤ %0.1f MB.",
        comms.progress.elapsed_time,
        total_mem.to(u.MB).value,
    )