Coverage for python / lsst / pipe / base / quantum_graph / aggregator / _communicators.py: 23%
367 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:59 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:59 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
# Public API of this module.
# NOTE(review): WriterCommunicator is defined below but not exported here;
# confirm whether that omission is intentional.
__all__ = (
    "FatalWorkerError",
    "IngesterCommunicator",
    "ScannerCommunicator",
    "SupervisorCommunicator",
)
37import cProfile
38import dataclasses
39import enum
40import logging
41import os
42import signal
43import time
44import uuid
45from collections.abc import Iterable, Iterator
46from contextlib import ExitStack
47from traceback import format_exception
48from types import TracebackType
49from typing import Literal, Self, overload
51from lsst.utils.logging import LsstLogAdapter
53from .._provenance import ProvenanceQuantumScanData
54from ._config import AggregatorConfig
55from ._progress import ProgressManager, make_worker_log
56from ._structs import IngestRequest, ScanReport
57from ._workers import Event, Queue, Worker, WorkerFactory
# Timeout (seconds) for queue polls in worker loops; kept small so blocking
# `get` calls return frequently enough for cancellation checks to stay
# responsive.
_TINY_TIMEOUT = 0.01
class FatalWorkerError(BaseException):
    """Exception used to propagate a shutdown signal between workers.

    Communicators raise this when one worker (including the supervisor) has
    caught an exception, in order to signal the others to shut down.

    This derives from `BaseException` rather than `Exception`, so broad
    ``except Exception`` handlers in worker code will not catch it.
    """
class _WorkerCommunicationError(Exception):
    """Exception raised by communicators when a worker has died unexpectedly
    or become unresponsive.
    """
class _Sentinel(enum.Enum):
    """Sentinel values used to indicate sequence points or worker shutdown
    conditions.
    """

    NO_MORE_SCAN_REQUESTS = enum.auto()
    """Sentinel sent from the supervisor to scanners to indicate that there are
    no more quanta left to be scanned.
    """

    NO_MORE_INGEST_REQUESTS = enum.auto()
    """Sentinel sent from scanners to the ingester to indicate that there
    will be no more ingest requests from a particular worker.
    """

    NO_MORE_WRITE_REQUESTS = enum.auto()
    """Sentinel sent from scanners and the supervisor to the writer to
    indicate that there will be no more write requests from a particular
    worker.
    """

    WRITE_REPORT = enum.auto()
    """Sentinel sent from the writer to the supervisor to report that a
    quantum's provenance was written.
    """
@dataclasses.dataclass
class _WorkerErrorMessage:
    """An internal message struct used to pass information about an error that
    occurred on a worker back to the supervisor.

    As a rule, these are unexpected, unrecoverable exceptions.
    """

    worker: str
    """Name of the originating worker."""

    traceback: str
    """A logged exception traceback.

    Note that this is not a `BaseException` subclass that can actually be
    re-raised on the supervisor; it's just something we can log to make the
    right traceback appear on the screen.  If something silences that printing
    in favor of its own exception management (pytest!) this information
    disappears.
    """
@dataclasses.dataclass
class _ScanRequest:
    """An internal struct passed from the supervisor to the scanners to request
    that a quantum be scanned.
    """

    quantum_id: uuid.UUID
    """ID of the quantum to be scanned."""
@dataclasses.dataclass
class _IngestReport:
    """An internal struct passed from the ingester to the supervisor to report
    a completed ingest batch.
    """

    n_producers: int
    """Number of producing quanta whose datasets were ingested.

    We use quanta rather than datasets as the count here because the supervisor
    knows the total number of quanta in advance but not the total number of
    datasets to be ingested, so it's a lot easier to attach a denominator
    and/or progress bar to this number.
    """
@dataclasses.dataclass
class _WorkerDone:
    """An internal struct passed from a worker to the supervisor when it has
    successfully completed all work.
    """

    name: str
    """Name of the worker reporting completion."""
@dataclasses.dataclass
class _ProgressLog:
    """A high-level log message sent from a worker to the supervisor.

    These are messages that should appear to come from the main
    'aggregate-graph' logger, not a worker-specific one.
    """

    message: str
    """Log message."""

    level: int
    """Log level (a `logging` module level integer)."""
@dataclasses.dataclass
class _CompressionDictionary:
    """An internal struct used to send the compression dictionary from the
    writer to the scanners.
    """

    data: bytes
    """The `bytes` representation of a `zstandard.ZstdCompressionDict`.
    """
# Union of every message type a worker may put on the supervisor's reports
# queue; dispatched in SupervisorCommunicator._handle_progress_reports.
type Report = (
    ScanReport
    | _IngestReport
    | _WorkerErrorMessage
    | _ProgressLog
    | _WorkerDone
    | Literal[_Sentinel.WRITE_REPORT]
)
def _disable_resources_parallelism() -> None:
    """Set environment variables that force ``lsst.resources`` (including its
    S3 backend) to run single-threaded in the current process.
    """
    env = os.environ
    env["LSST_RESOURCES_NUM_WORKERS"] = "1"
    env["LSST_S3_USE_THREADS"] = "False"
    # Drop any explicitly-configured executor; ignore the case where it is
    # already unset.
    env.pop("LSST_RESOURCES_EXECUTOR", None)
class SupervisorCommunicator:
    """A helper object that lets the supervisor direct the other workers.

    Parameters
    ----------
    log : `lsst.utils.logging.LsstLogAdapter`
        LSST-customized logger.
    n_scanners : `int`
        Number of scanner workers.
    worker_factory : `WorkerFactory`
        Abstraction over threading vs. multiprocessing.
    config : `AggregatorConfig`
        Configuration for the aggregator.
    """

    def __init__(
        self,
        log: LsstLogAdapter,
        n_scanners: int,
        worker_factory: WorkerFactory,
        config: AggregatorConfig,
    ) -> None:
        self.config = config
        self.progress = ProgressManager(log, config)
        self.n_scanners = n_scanners
        # The supervisor sends scan requests to scanners on this queue.
        # When complete, the supervisor sends n_scanners sentinels and each
        # scanner is careful to only take one before it starts its shutdown.
        self._scan_requests: Queue[_ScanRequest | Literal[_Sentinel.NO_MORE_SCAN_REQUESTS]] = (
            worker_factory.make_queue()
        )
        # The scanners send ingest requests to the ingester on this queue. Each
        # scanner sends one sentinel when it is done, and the ingester is
        # careful to wait for n_scanners sentinels to arrive before it starts
        # its shutdown.
        self._ingest_requests: Queue[IngestRequest | Literal[_Sentinel.NO_MORE_INGEST_REQUESTS]] = (
            worker_factory.make_queue()
        )
        # The scanners send write requests to the writer on this queue (which
        # will be `None` if we're not writing). The supervisor also sends
        # write requests for blocked quanta (which we don't scan). Each
        # scanner and the supervisor send one sentinel when done, and the
        # writer waits for (n_scanners + 1) sentinels to arrive before it
        # starts its shutdown.
        self._write_requests: (
            Queue[ProvenanceQuantumScanData | Literal[_Sentinel.NO_MORE_WRITE_REQUESTS]] | None
        ) = worker_factory.make_queue() if config.is_writing_provenance else None
        # All other workers use this queue to send many different kinds of
        # reports to the supervisor. The supervisor waits for a _WorkerDone
        # message from each worker before it finishes its shutdown.
        self._reports: Queue[Report] = worker_factory.make_queue()
        # The writer sends the compression dictionary to the scanners on this
        # queue. It puts n_scanners copies on the queue, and each scanner only
        # takes one. The compression_dict queue has no sentinel because it is
        # only used at most once; the supervisor takes responsibility for
        # clearing it out while shutting down.
        self._compression_dict: Queue[_CompressionDictionary] = worker_factory.make_queue()
        # The supervisor sets this event when it receives an interrupt request
        # from an exception in the main process (usually KeyboardInterrupt).
        # Worker communicators check this in their polling loops and raise
        # FatalWorkerError when they see it set.
        self._cancel_event: Event = worker_factory.make_event()
        # Track what state we are in closing down, so we can start at the right
        # point if we're interrupted and __exit__ needs to clean up. Note that
        # we can't rely on a non-exception __exit__ to do any shutdown work
        # that might be slow, since a KeyboardInterrupt that occurs when
        # __exit__ is already running can't be caught inside __exit__.
        self._sent_no_more_scan_requests = False
        self._sent_no_more_write_requests = False
        self._n_scanners_done = 0
        # Workers registered by name; each worker's `successful` flag is set
        # as its _WorkerDone message arrives.
        self.workers: dict[str, Worker] = {}

    def _wait_for_workers_to_finish(self, already_failing: bool = False) -> None:
        """Perform an orderly shutdown: send the remaining shutdown sentinels,
        process reports until every worker has reported success, then drain
        and kill all queues and join the workers.

        Parameters
        ----------
        already_failing : `bool`
            If `True`, we are shutting down because of an existing exception,
            so worker error messages are logged but do not raise again.
        """
        # Orderly shutdown, including exceptions: let workers clear out the
        # queues they're responsible for reading from.
        if not self._sent_no_more_scan_requests:
            for _ in range(self.n_scanners):
                self._scan_requests.put(_Sentinel.NO_MORE_SCAN_REQUESTS)
            self._sent_no_more_scan_requests = True
        if not self._sent_no_more_write_requests and self._write_requests is not None:
            self._write_requests.put(_Sentinel.NO_MORE_WRITE_REQUESTS)
            self._sent_no_more_write_requests = True
        while not all(w.successful for w in self.workers.values()):
            match self._handle_progress_reports(
                self._get_report(block=True), already_failing=already_failing
            ):
                case None | ScanReport():
                    # Late scan reports can still arrive while shutting down;
                    # there is nothing left to do with them.
                    pass
                case _WorkerDone(name=worker_name):
                    self.workers[worker_name].successful = True
                    if worker_name == IngesterCommunicator.get_worker_name():
                        self.progress.quantum_ingests.close()
                    elif worker_name == WriterCommunicator.get_worker_name():
                        self.progress.writes.close()
                    else:
                        # Any other name is a scanner; close the scan progress
                        # bar only when the last one finishes.
                        self._n_scanners_done += 1
                        if self._n_scanners_done == self.n_scanners:
                            self.progress.scans.close()
                case unexpected:
                    raise AssertionError(f"Unexpected message {unexpected!r} to supervisor.")
            self.log.verbose(
                "Waiting for workers [%s] to report successful completion.",
                ", ".join(w.name for w in self.workers.values() if not w.successful),
            )
        self.log.verbose("Checking that all queues are empty.")
        if self._scan_requests.clear():
            self.progress.log.warning("Scan request queue was not empty at shutdown.")
        self._scan_requests.kill()
        if self._ingest_requests.clear():
            self.progress.log.warning("Ingest request queue was not empty at shutdown.")
        self._ingest_requests.kill()
        if self._write_requests is not None and self._write_requests.clear():
            self.progress.log.warning("Write request queue was not empty at shutdown.")
            self._write_requests.kill()
        if self._reports.clear():
            self.progress.log.warning("Reports queue was not empty at shutdown.")
        self._reports.kill()
        if self._compression_dict.clear():
            self.progress.log.warning("Compression dictionary queue was not empty at shutdown.")
        self._compression_dict.kill()
        for worker in self.workers.values():
            self.log.verbose("Waiting for %s to shut down.", worker.name)
            worker.join()

    def _terminate(self) -> None:
        """Perform a disorderly shutdown, killing queues and live workers."""
        # Disorderly shutdown: we cannot assume any of the
        # multiprocessing.Queue object work, and in fact they may hang
        # if we try to do anything with them.
        self._scan_requests.kill()
        self._ingest_requests.kill()
        if self._write_requests is not None:
            self._write_requests.kill()
        self._compression_dict.kill()
        self._reports.kill()
        for name, worker in self.workers.items():
            if worker.is_alive():
                self.progress.log.critical("Terminating worker %r.", name)
                worker.kill()

    def __enter__(self) -> Self:
        _disable_resources_parallelism()
        self.progress.__enter__()
        # We make the low-level logger in __enter__ instead of __init__ only
        # because that's the pattern used by true workers (where it matters).
        self.log = make_worker_log("supervisor", self.config)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        if exc_type is not None:
            # Tell all workers to stop what they're doing as soon as they
            # next check the cancel event.
            self._cancel_event.set()
            if exc_type is _WorkerCommunicationError:
                self.progress.log.critical("Worker '%s' was terminated before it could finish.", exc_value)
                self._terminate()
                # NOTE(review): this early return skips progress.__exit__;
                # confirm that is intentional on the disorderly-shutdown path.
                return None
            if exc_type is not FatalWorkerError:
                self.progress.log.critical("Caught %s; attempting to shut down cleanly.", exc_type)
        try:
            self._wait_for_workers_to_finish(already_failing=exc_type is not None)
        except _WorkerCommunicationError as err:
            self.progress.log.critical(
                "Worker '%s' was terminated before it could finish (after scanning).", err
            )
            self._terminate()
            raise
        self.progress.__exit__(exc_type, exc_value, traceback)
        return None

    def request_scan(self, quantum_id: uuid.UUID) -> None:
        """Send a request to the scanners to scan the given quantum.

        Parameters
        ----------
        quantum_id : `uuid.UUID`
            ID of the quantum to scan.
        """
        self._scan_requests.put(_ScanRequest(quantum_id))

    def request_write(self, request: ProvenanceQuantumScanData) -> None:
        """Send a request to the writer to write provenance for the given scan.

        Parameters
        ----------
        request : `ProvenanceQuantumScanData`
            Information from scanning a quantum (or knowing you don't have to,
            in the case of blocked quanta).
        """
        assert self._write_requests is not None, "Writer should not be used if writing is disabled."
        self._write_requests.put(request)

    def poll(self) -> Iterator[ScanReport]:
        """Poll for reports from workers while sending scan requests.

        Yields
        ------
        scan_report : `ScanReport`
            A report from a scanner that a quantum was scanned.

        Notes
        -----
        This iterator blocks until the first scan report is received, and then
        it continues until the report queue is empty.
        """
        block = True
        while report := self._get_report(block=block):
            match self._handle_progress_reports(report):
                case ScanReport() as scan_report:
                    # After the first scan report, switch to non-blocking gets
                    # so the caller can resume sending scan requests as soon
                    # as the queue empties.
                    block = False
                    yield scan_report
                case None:
                    pass
                case unexpected:
                    raise AssertionError(f"Unexpected message {unexpected!r} to supervisor.")

    @overload
    def _get_report(self, block: Literal[True]) -> Report: ...

    @overload
    def _get_report(self, block: bool) -> Report | None: ...

    def _get_report(self, block: bool) -> Report | None:
        """Get a report from the reports queue, with timeout guards on
        blocking requests.

        Raises
        ------
        _WorkerCommunicationError
            Raised if a worker that has not yet reported success is found
            dead while we are waiting for a report; ``__exit__`` catches this
            and switches to a disorderly shutdown.
        """
        report = self._reports.get(block=block, timeout=self.config.worker_check_timeout)
        while report is None and block:
            # We hit the timeout; make sure all of the workers
            # that should be alive actually are.
            for name, worker in self.workers.items():
                if not worker.successful and not worker.is_alive():
                    # A worker died without reporting success; bail out so
                    # the caller can abandon the orderly shutdown.
                    raise _WorkerCommunicationError(name)
            # If nothing is dead and we didn't hit the hang timeout, keep
            # trying.
            report = self._reports.get(block=block, timeout=self.config.worker_check_timeout)
        return report

    def _handle_progress_reports(
        self, report: Report, already_failing: bool = False
    ) -> ScanReport | _WorkerDone | None:
        """Handle reports to the supervisor that can appear at any time, and
        are typically just updates to the progress we've made.

        This includes:

        - exceptions from workers (which raise `FatalWorkerError` here to
          trigger ``__exit__``);
        - ingest reports;
        - write reports;
        - progress logs.

        If one of these is handled, `None` is returned; otherwise the original
        report is returned.
        """
        match report:
            case _WorkerErrorMessage(traceback=traceback, worker=worker):
                self.progress.log.fatal("Exception raised on %s: \n%s", worker, traceback)
                # When we're already shutting down due to a failure, just log
                # additional worker errors instead of raising again.
                if not already_failing:
                    raise FatalWorkerError()
            case _IngestReport(n_producers=n_producers):
                self.progress.quantum_ingests.update(n_producers)
            case _Sentinel.WRITE_REPORT:
                self.progress.writes.update(1)
            case _ProgressLog(message=message, level=level):
                self.progress.log.log(level, "%s [after %0.1fs]", message, self.progress.elapsed_time)
            case _:
                return report
        return None
class WorkerCommunicator:
    """A base class for non-supervisor worker communicators.

    Parameters
    ----------
    supervisor : `SupervisorCommunicator`
        Communicator for the supervisor to grab queues and information from.
    name : `str`
        Human-readable name for this worker.

    Notes
    -----
    Each worker communicator is constructed in the main process and entered as
    a context manager *only* on the actual worker process, so attributes that
    cannot be pickled are constructed in ``__enter__`` instead of ``__init__``.

    Worker communicators provide access to an `AggregatorConfig` and a logger
    to their workers.  As context managers, they handle exceptions and ensure
    clean shutdowns, and since most workers need to use a lot of other context
    managers (for file reading and writing, mostly), they provide an
    `exit_stack` property to keep every worker from also having to be a
    context manager just to hold a context manager instance attribute.

    Worker communicators can also be configured to record and dump profiling
    information.
    """

    def __init__(self, supervisor: SupervisorCommunicator, name: str):
        self.name = name
        self.config = supervisor.config
        # Shared with the supervisor; these must survive pickling when the
        # communicator is shipped to a worker process.
        self._reports = supervisor._reports
        self._cancel_event = supervisor._cancel_event

    def __enter__(self) -> Self:
        # Runs on the worker process/thread itself, so unpicklable attributes
        # (logger, exit stack, profiler) are created here, not in __init__.
        _disable_resources_parallelism()
        self.log = make_worker_log(self.name, self.config)
        self.log.verbose("%s has PID %s (parent is %s).", self.name, os.getpid(), os.getppid())
        self._exit_stack = ExitStack().__enter__()
        if self.config.n_processes > 1:
            # Multiprocessing: ignore interrupts so we can shut down cleanly.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
        if self.config.worker_profile_dir is not None:
            # We use time.time because we're interested in wall-clock time,
            # not just CPU effort, since this is I/O-bound work.
            self._profiler = cProfile.Profile(timer=time.time)
            self._profiler.enable()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        # Dump profile stats whenever profiling was enabled in __enter__.
        # This condition now mirrors __enter__'s exactly; the previous extra
        # ``n_processes > 1`` guard meant single-process (threaded) runs
        # enabled the profiler but never disabled it or dumped its stats.
        if self.config.worker_profile_dir is not None:
            self._profiler.disable()
            os.makedirs(self.config.worker_profile_dir, exist_ok=True)
            self._profiler.dump_stats(os.path.join(self.config.worker_profile_dir, f"{self.name}.profile"))
        if exc_value is not None:
            assert exc_type is not None, "Should be guaranteed by Python, but MyPy doesn't know that."
            if exc_type is not FatalWorkerError:
                # This worker is the origin of the failure; format the
                # traceback and ship it to the supervisor for logging there.
                self.log.warning("Error raised on this worker.", exc_info=(exc_type, exc_value, traceback))
                assert traceback is not None
                self._reports.put(
                    _WorkerErrorMessage(
                        self.name,
                        "".join(format_exception(exc_type, exc_value, traceback)),
                    )
                )
                self.log.debug("Error message sent to supervisor.")
            else:
                self.log.warning("Shutting down due to exception raised on another worker.")
        self._exit_stack.__exit__(exc_type, exc_value, traceback)
        # Always swallow the exception: failures are reported to the
        # supervisor via the reports queue rather than propagated here.
        return True

    @property
    def exit_stack(self) -> ExitStack:
        """A `contextlib.ExitStack` tied to the communicator."""
        return self._exit_stack

    def log_progress(self, level: int, message: str) -> None:
        """Send a high-level log message to the supervisor.

        Parameters
        ----------
        level : `int`
            Log level.  Should be ``VERBOSE`` or higher.
        message : `str`
            Log message.
        """
        self._reports.put(_ProgressLog(message=message, level=level))

    def check_for_cancel(self) -> None:
        """Check for a cancel signal from the supervisor and raise
        `FatalWorkerError` if it is present.
        """
        if self._cancel_event.is_set():
            raise FatalWorkerError()
class ScannerCommunicator(WorkerCommunicator):
    """A communicator for scanner workers.

    Parameters
    ----------
    supervisor : `SupervisorCommunicator`
        Communicator for the supervisor to grab queues and information from.
    scanner_id : `int`
        Integer ID for this scanner.
    """

    def __init__(self, supervisor: SupervisorCommunicator, scanner_id: int):
        super().__init__(supervisor, self.get_worker_name(scanner_id))
        self.scanner_id = scanner_id
        self._scan_requests = supervisor._scan_requests
        self._ingest_requests = supervisor._ingest_requests
        self._write_requests = supervisor._write_requests
        self._compression_dict = supervisor._compression_dict
        # Set to True once this scanner has consumed its
        # NO_MORE_SCAN_REQUESTS sentinel, in poll() or in __exit__.
        self._got_no_more_scan_requests: bool = False
        # NOTE(review): this flag is never read or updated in this module;
        # confirm whether it is vestigial.
        self._sent_no_more_ingest_requests: bool = False

    @staticmethod
    def get_worker_name(scanner_id: int) -> str:
        """Return the standard worker name for the scanner with this ID."""
        return f"scanner-{scanner_id:03d}"

    def report_scan(self, msg: ScanReport) -> None:
        """Report a completed scan to the supervisor.

        Parameters
        ----------
        msg : `ScanReport`
            Report to send.
        """
        self._reports.put(msg)

    def request_ingest(self, request: IngestRequest) -> None:
        """Ask the ingester to ingest a quantum's outputs.

        Parameters
        ----------
        request : `IngestRequest`
            Description of the datasets to ingest.

        Notes
        -----
        If this request has no datasets, this automatically reports the ingest
        as complete to the supervisor instead of sending it to the ingester.
        """
        if request:
            self._ingest_requests.put(request)
        else:
            # Nothing to ingest: report the producing quantum as handled so
            # the supervisor's progress accounting stays correct.
            self._reports.put(_IngestReport(1))

    def request_write(self, request: ProvenanceQuantumScanData) -> None:
        """Ask the writer to write provenance for a quantum.

        Parameters
        ----------
        request : `ProvenanceQuantumScanData`
            Result of scanning a quantum.
        """
        assert self._write_requests is not None, "Writer should not be used if writing is disabled."
        self._write_requests.put(request)

    def get_compression_dict(self) -> bytes | None:
        """Attempt to get the compression dict from the writer.

        Returns
        -------
        data : `bytes` or `None`
            The `bytes` representation of the compression dictionary, or `None`
            if the compression dictionary is not yet available.

        Notes
        -----
        A scanner should only call this method before it actually has the
        compression dict.
        """
        if (cdict := self._compression_dict.get()) is not None:
            return cdict.data
        return None

    def poll(self) -> Iterator[uuid.UUID]:
        """Poll for scan requests to process.

        Yields
        ------
        quantum_id : `uuid.UUID`
            ID of a new quantum to scan.

        Notes
        -----
        This iterator ends when the supervisor reports that it is done
        traversing the graph.
        """
        while True:
            self.check_for_cancel()
            scan_request = self._scan_requests.get(block=True, timeout=self.config.worker_sleep)
            if scan_request is _Sentinel.NO_MORE_SCAN_REQUESTS:
                self._got_no_more_scan_requests = True
                return
            if scan_request is not None:
                yield scan_request.quantum_id

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        result = super().__exit__(exc_type, exc_value, traceback)
        # Tell the downstream workers that no more requests will come from
        # this scanner.
        self._ingest_requests.put(_Sentinel.NO_MORE_INGEST_REQUESTS)
        if self._write_requests is not None:
            self._write_requests.put(_Sentinel.NO_MORE_WRITE_REQUESTS)
        # Consume this scanner's shutdown sentinel if poll() didn't already,
        # discarding any other leftover scan requests along the way.  (The
        # loop condition alone is sufficient; a redundant re-test of the flag
        # inside the loop body has been removed.)
        while not self._got_no_more_scan_requests:
            if self._scan_requests.get(block=True) is _Sentinel.NO_MORE_SCAN_REQUESTS:
                self._got_no_more_scan_requests = True
        # We let the writer clear out the compression dict queue.
        self.log.verbose("Sending completion message.")
        self._reports.put(_WorkerDone(self.name))
        return result
class IngesterCommunicator(WorkerCommunicator):
    """A communicator for the ingester worker.

    Parameters
    ----------
    supervisor : `SupervisorCommunicator`
        Communicator for the supervisor to grab queues and information from.
    """

    def __init__(self, supervisor: SupervisorCommunicator):
        super().__init__(supervisor, self.get_worker_name())
        self.n_scanners = supervisor.n_scanners
        self._ingest_requests = supervisor._ingest_requests
        # Number of scanners whose NO_MORE_INGEST_REQUESTS sentinel we have
        # consumed so far.
        self._n_requesters_done = 0

    @staticmethod
    def get_worker_name() -> str:
        """Return the fixed worker name for the ingester."""
        return "ingester"

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        result = super().__exit__(exc_type, exc_value, traceback)
        # Drain our queue until every scanner's shutdown sentinel has been
        # seen, discarding anything else still queued, so the supervisor's
        # final emptiness checks pass.
        while self._n_requesters_done < self.n_scanners:
            self.log.debug(
                "Waiting for %d requesters to be done (currently %d).",
                self.n_scanners,
                self._n_requesters_done,
            )
            message = self._ingest_requests.get(block=True)
            if message is _Sentinel.NO_MORE_INGEST_REQUESTS:
                self._n_requesters_done += 1
        self.log.verbose("Sending completion message.")
        self._reports.put(_WorkerDone(self.name))
        return result

    def report_ingest(self, n_producers: int) -> None:
        """Report to the supervisor that an ingest batch was completed.

        Parameters
        ----------
        n_producers : `int`
            Number of producing quanta whose datasets were ingested.
        """
        self._reports.put(_IngestReport(n_producers))

    def poll(self) -> Iterator[IngestRequest]:
        """Poll for ingest requests from the scanner workers.

        Yields
        ------
        request : `IngestRequest`
            A request to ingest datasets produced by a single quantum.

        Notes
        -----
        This iterator ends when all scanners indicate that they are done
        making ingest requests.
        """
        while True:
            self.check_for_cancel()
            message = self._ingest_requests.get(block=True, timeout=_TINY_TIMEOUT)
            if message is None:
                # Timed out; loop back so cancellation stays responsive.
                continue
            if message is _Sentinel.NO_MORE_INGEST_REQUESTS:
                self._n_requesters_done += 1
                if self._n_requesters_done == self.n_scanners:
                    return
            else:
                yield message
783class WriterCommunicator(WorkerCommunicator):
784 """A communicator for the writer worker.
786 Parameters
787 ----------
788 supervisor : `SupervisorCommunicator`
789 Communicator for the supervisor to grab queues and information from.
790 """
792 def __init__(self, supervisor: SupervisorCommunicator):
793 assert supervisor._write_requests is not None
794 super().__init__(supervisor, self.get_worker_name())
795 self.n_scanners = supervisor.n_scanners
796 self._write_requests = supervisor._write_requests
797 self._compression_dict = supervisor._compression_dict
798 self._n_requesters = supervisor.n_scanners + 1
799 self._n_requesters_done = 0
800 self._sent_compression_dict = False
802 @staticmethod
803 def get_worker_name() -> str:
804 return "writer"
806 def __exit__(
807 self,
808 exc_type: type[BaseException] | None,
809 exc_value: BaseException | None,
810 traceback: TracebackType | None,
811 ) -> bool | None:
812 result = super().__exit__(exc_type, exc_value, traceback)
813 if exc_type is None:
814 self.log_progress(logging.INFO, "Provenance quantum graph written successfully.")
815 while self._n_requesters_done != self._n_requesters:
816 self.log.debug(
817 "Waiting for %d requesters to be done (currently %d).",
818 self._n_requesters,
819 self._n_requesters_done,
820 )
821 if self._write_requests.get(block=True) is _Sentinel.NO_MORE_WRITE_REQUESTS:
822 self._n_requesters_done += 1
823 if self._compression_dict.clear():
824 self.log.verbose("Cleared out compression dictionary queue.")
825 else:
826 self.log.verbose("Compression dictionary queue was already empty.")
827 self.log.verbose("Sending completion message.")
828 self._reports.put(_WorkerDone(self.name))
829 return result
831 def poll(self) -> Iterator[ProvenanceQuantumScanData]:
832 """Poll for writer requests from the scanner workers and supervisor.
834 Yields
835 ------
836 request : `ProvenanceQuantumScanData`
837 The result of a quantum scan.
839 Notes
840 -----
841 This iterator ends when all scanners and the supervisor indicate that
842 they are done making write requests.
843 """
844 while True:
845 self.check_for_cancel()
846 write_request = self._write_requests.get(block=True, timeout=_TINY_TIMEOUT)
847 if write_request is _Sentinel.NO_MORE_WRITE_REQUESTS:
848 self._n_requesters_done += 1
849 if self._n_requesters_done == self._n_requesters:
850 return
851 else:
852 continue
853 if write_request is not None:
854 yield write_request
856 def send_compression_dict(self, cdict_data: bytes) -> None:
857 """Send the compression dictionary to the scanners.
859 Parameters
860 ----------
861 cdict_data : `bytes`
862 The `bytes` representation of the compression dictionary.
863 """
864 self.log.debug("Sending compression dictionary.")
865 for _ in range(self.n_scanners):
866 self._compression_dict.put(_CompressionDictionary(cdict_data))
867 self._sent_compression_dict = True
869 def report_write(self) -> None:
870 """Report to the supervisor that provenance for a quantum was written
871 to the graph.
872 """
873 self._reports.put(_Sentinel.WRITE_REPORT)
875 def periodically_check_for_cancel[T](self, iterable: Iterable[T], n: int = 100) -> Iterator[T]:
876 """Iterate while checking for a cancellation signal every ``n``
877 iterations.
879 Parameters
880 ----------
881 iterable : `~collections.abc.Iterable`
882 Object to iterate over.
883 n : `int`
884 Check for cancellation every ``n`` iterations.
886 Returns
887 -------
888 iterator : `~collections.abc.Iterator`
889 Iterator.
890 """
891 i = 0
892 for entry in iterable:
893 yield entry
894 i += 1
895 if i % n == 0:
896 self.check_for_cancel()