Coverage for python / lsst / pipe / base / quantum_graph / aggregator / _communicators.py: 23%

367 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:57 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

__all__ = (
    "FatalWorkerError",
    "IngesterCommunicator",
    "ScannerCommunicator",
    "SupervisorCommunicator",
    # NOTE(review): WriterCommunicator is defined in this module but not
    # exported here — confirm the omission is intentional.
)

36 

37import cProfile 

38import dataclasses 

39import enum 

40import logging 

41import os 

42import signal 

43import time 

44import uuid 

45from collections.abc import Iterable, Iterator 

46from contextlib import ExitStack 

47from traceback import format_exception 

48from types import TracebackType 

49from typing import Literal, Self, overload 

50 

51from lsst.utils.logging import LsstLogAdapter 

52 

53from .._provenance import ProvenanceQuantumScanData 

54from ._config import AggregatorConfig 

55from ._progress import ProgressManager, make_worker_log 

56from ._structs import IngestRequest, ScanReport 

57from ._workers import Event, Queue, Worker, WorkerFactory 

58 

# Short polling timeout (seconds) used by the ingester and writer when
# blocking on their request queues, so cancellation checks between get()
# attempts stay responsive.
_TINY_TIMEOUT = 0.01

60 

61 

class FatalWorkerError(BaseException):
    """An exception raised by communicators when one worker (including the
    supervisor) has caught an exception, in order to signal the others to
    shut down.

    Notes
    -----
    This derives from `BaseException` rather than `Exception` — presumably so
    that generic ``except Exception`` handlers in worker code do not swallow
    the shutdown signal (TODO confirm).
    """

67 

68 

class _WorkerCommunicationError(Exception):
    """An exception raised by communicators when a worker has died
    unexpectedly or become unresponsive.

    Notes
    -----
    The exception argument is the name of the worker that died, so it can be
    interpolated directly into log messages by handlers.
    """

73 

74 

class _Sentinel(enum.Enum):
    """Sentinel values used to indicate sequence points or worker shutdown
    conditions.
    """

    NO_MORE_SCAN_REQUESTS = enum.auto()
    """Sentinel sent from the supervisor to scanners to indicate that there
    are no more quanta left to be scanned.
    """

    NO_MORE_INGEST_REQUESTS = enum.auto()
    """Sentinel sent from scanners to the ingester to indicate that there
    will be no more ingest requests from a particular worker.
    """

    NO_MORE_WRITE_REQUESTS = enum.auto()
    """Sentinel sent from scanners and the supervisor to the writer to
    indicate that there will be no more write requests from a particular
    worker.
    """

    WRITE_REPORT = enum.auto()
    """Sentinel sent from the writer to the supervisor to report that a
    quantum's provenance was written.
    """

100 

101 

@dataclasses.dataclass
class _WorkerErrorMessage:
    """An internal struct used to pass information about an error that
    occurred on a worker back to the supervisor.

    As a rule, these are unexpected, unrecoverable exceptions.
    """

    worker: str
    """Name of the originating worker."""

    traceback: str
    """A logged exception traceback.

    Note that this is not a `BaseException` subclass that can actually be
    re-raised on the supervisor; it's just something we can log to make the
    right traceback appear on the screen.  If something silences that
    printing in favor of its own exception management (pytest!) this
    information disappears.
    """

122 

123 

@dataclasses.dataclass
class _ScanRequest:
    """An internal struct passed from the supervisor to the scanners to
    request that a quantum be scanned.
    """

    quantum_id: uuid.UUID
    """ID of the quantum to be scanned."""

132 

133 

@dataclasses.dataclass
class _IngestReport:
    """An internal struct passed from the ingester to the supervisor to
    report a completed ingest batch.
    """

    n_producers: int
    """Number of producing quanta whose datasets were ingested.

    We use quanta rather than datasets as the count here because the
    supervisor knows the total number of quanta in advance but not the total
    number of datasets to be ingested, so it's a lot easier to attach a
    denominator and/or progress bar to this number.
    """

148 

149 

@dataclasses.dataclass
class _WorkerDone:
    """An internal struct passed from a worker to the supervisor when it has
    successfully completed all work.
    """

    name: str
    """Name of the worker reporting completion."""

158 

159 

@dataclasses.dataclass
class _ProgressLog:
    """A high-level log message sent from a worker to the supervisor.

    These are messages that should appear to come from the main
    'aggregate-graph' logger, not a worker-specific one.
    """

    message: str
    """Log message."""

    level: int
    """Log level (a `logging` module level constant)."""

173 

174 

@dataclasses.dataclass
class _CompressionDictionary:
    """An internal struct used to send the compression dictionary from the
    writer to the scanners.
    """

    data: bytes
    """The `bytes` representation of a `zstandard.ZstdCompressionDict`."""

184 

185 

# Union of every message type a worker may put on the supervisor's reports
# queue; see SupervisorCommunicator._handle_progress_reports for handling.
type Report = (
    ScanReport
    | _IngestReport
    | _WorkerErrorMessage
    | _ProgressLog
    | _WorkerDone
    | Literal[_Sentinel.WRITE_REPORT]
)

194 

195 

196def _disable_resources_parallelism() -> None: 

197 os.environ["LSST_RESOURCES_NUM_WORKERS"] = "1" 

198 os.environ.pop("LSST_RESOURCES_EXECUTOR", None) 

199 os.environ["LSST_S3_USE_THREADS"] = "False" 

200 

201 

class SupervisorCommunicator:
    """A helper object that lets the supervisor direct the other workers.

    Parameters
    ----------
    log : `lsst.utils.logging.LsstLogAdapter`
        LSST-customized logger.
    n_scanners : `int`
        Number of scanner workers.
    worker_factory : `WorkerFactory`
        Abstraction over threading vs. multiprocessing.
    config : `AggregatorConfig`
        Configuration for the aggregator.
    """

    def __init__(
        self,
        log: LsstLogAdapter,
        n_scanners: int,
        worker_factory: WorkerFactory,
        config: AggregatorConfig,
    ) -> None:
        self.config = config
        self.progress = ProgressManager(log, config)
        self.n_scanners = n_scanners
        # The supervisor sends scan requests to scanners on this queue.
        # When complete, the supervisor sends n_scanners sentinels and each
        # scanner is careful to only take one before it starts its shutdown.
        self._scan_requests: Queue[_ScanRequest | Literal[_Sentinel.NO_MORE_SCAN_REQUESTS]] = (
            worker_factory.make_queue()
        )
        # The scanners send ingest requests to the ingester on this queue.
        # Each scanner sends one sentinel when it is done, and the ingester
        # is careful to wait for n_scanners sentinels to arrive before it
        # starts its shutdown.
        self._ingest_requests: Queue[IngestRequest | Literal[_Sentinel.NO_MORE_INGEST_REQUESTS]] = (
            worker_factory.make_queue()
        )
        # The scanners send write requests to the writer on this queue (which
        # will be `None` if we're not writing).  The supervisor also sends
        # write requests for blocked quanta (which we don't scan).  Each
        # scanner and the supervisor send one sentinel when done, and the
        # writer waits for (n_scanners + 1) sentinels to arrive before it
        # starts its shutdown.
        self._write_requests: (
            Queue[ProvenanceQuantumScanData | Literal[_Sentinel.NO_MORE_WRITE_REQUESTS]] | None
        ) = worker_factory.make_queue() if config.is_writing_provenance else None
        # All other workers use this queue to send many different kinds of
        # reports to the supervisor.  The supervisor waits for a _WorkerDone
        # message from each worker before it finishes its shutdown.
        self._reports: Queue[Report] = worker_factory.make_queue()
        # The writer sends the compression dictionary to the scanners on this
        # queue.  It puts n_scanners copies on the queue, and each scanner
        # only takes one.  The compression_dict queue has no sentinel because
        # it is only used at most once; the supervisor takes responsibility
        # for clearing it out when shutting down.
        self._compression_dict: Queue[_CompressionDictionary] = worker_factory.make_queue()
        # The supervisor sets this event when it receives an interrupt
        # request from an exception in the main process (usually
        # KeyboardInterrupt).  Worker communicators check this in their
        # polling loops and raise FatalWorkerError when they see it set.
        self._cancel_event: Event = worker_factory.make_event()
        # Track what state we are in closing down, so we can start at the
        # right point if we're interrupted and __exit__ needs to clean up.
        # Note that we can't rely on a non-exception __exit__ to do any
        # shutdown work that might be slow, since a KeyboardInterrupt that
        # occurs when __exit__ is already running can't be caught inside
        # __exit__.
        self._sent_no_more_scan_requests = False
        self._sent_no_more_write_requests = False
        self._n_scanners_done = 0
        self.workers: dict[str, Worker] = {}

    def _wait_for_workers_to_finish(self, already_failing: bool = False) -> None:
        """Run the orderly-shutdown protocol: send any unsent shutdown
        sentinels, then process reports until every worker has reported
        successful completion.

        Parameters
        ----------
        already_failing : `bool`, optional
            If `True`, an exception is already propagating, so worker error
            messages are logged but do not raise `FatalWorkerError` again.
        """
        # Orderly shutdown, including exceptions: let workers clear out the
        # queues they're responsible for reading from.
        if not self._sent_no_more_scan_requests:
            for _ in range(self.n_scanners):
                self._scan_requests.put(_Sentinel.NO_MORE_SCAN_REQUESTS)
            self._sent_no_more_scan_requests = True
        if not self._sent_no_more_write_requests and self._write_requests is not None:
            self._write_requests.put(_Sentinel.NO_MORE_WRITE_REQUESTS)
            self._sent_no_more_write_requests = True
        while not all(w.successful for w in self.workers.values()):
            match self._handle_progress_reports(
                self._get_report(block=True), already_failing=already_failing
            ):
                case None | ScanReport():
                    # Progress messages were already handled; any straggling
                    # scan reports are ignored at this stage.
                    pass
                case _WorkerDone(name=worker_name):
                    self.workers[worker_name].successful = True
                    if worker_name == IngesterCommunicator.get_worker_name():
                        self.progress.quantum_ingests.close()
                    elif worker_name == WriterCommunicator.get_worker_name():
                        self.progress.writes.close()
                    else:
                        # Any other name is a scanner.
                        self._n_scanners_done += 1
                        if self._n_scanners_done == self.n_scanners:
                            self.progress.scans.close()
                case unexpected:
                    raise AssertionError(f"Unexpected message {unexpected!r} to supervisor.")
            self.log.verbose(
                "Waiting for workers [%s] to report successful completion.",
                ", ".join(w.name for w in self.workers.values() if not w.successful),
            )
        self.log.verbose("Checking that all queues are empty.")
        if self._scan_requests.clear():
            self.progress.log.warning("Scan request queue was not empty at shutdown.")
            self._scan_requests.kill()
        if self._ingest_requests.clear():
            self.progress.log.warning("Ingest request queue was not empty at shutdown.")
            self._ingest_requests.kill()
        if self._write_requests is not None and self._write_requests.clear():
            self.progress.log.warning("Write request queue was not empty at shutdown.")
            self._write_requests.kill()
        if self._reports.clear():
            self.progress.log.warning("Reports queue was not empty at shutdown.")
            self._reports.kill()
        if self._compression_dict.clear():
            self.progress.log.warning("Compression dictionary queue was not empty at shutdown.")
            self._compression_dict.kill()
        for worker in self.workers.values():
            self.log.verbose("Waiting for %s to shut down.", worker.name)
            worker.join()

    def _terminate(self) -> None:
        """Perform a disorderly shutdown after a fatal communication error."""
        # Disorderly shutdown: we cannot assume any of the
        # multiprocessing.Queue object work, and in fact they may hang
        # if we try to do anything with them.
        self._scan_requests.kill()
        self._ingest_requests.kill()
        if self._write_requests is not None:
            self._write_requests.kill()
        self._compression_dict.kill()
        self._reports.kill()
        for name, worker in self.workers.items():
            if worker.is_alive():
                self.progress.log.critical("Terminating worker %r.", name)
                worker.kill()

    def __enter__(self) -> Self:
        _disable_resources_parallelism()
        self.progress.__enter__()
        # We make the low-level logger in __enter__ instead of __init__ only
        # because that's the pattern used by true workers (where it matters).
        self.log = make_worker_log("supervisor", self.config)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        if exc_type is not None:
            # Tell all worker communicators to raise FatalWorkerError in
            # their polling loops.
            self._cancel_event.set()
            if exc_type is _WorkerCommunicationError:
                # A worker is already dead; an orderly shutdown is impossible.
                self.progress.log.critical("Worker '%s' was terminated before it could finish.", exc_value)
                self._terminate()
                return None
            if exc_type is not FatalWorkerError:
                self.progress.log.critical("Caught %s; attempting to shut down cleanly.", exc_type)
        try:
            self._wait_for_workers_to_finish(already_failing=exc_type is not None)
        except _WorkerCommunicationError as err:
            self.progress.log.critical(
                "Worker '%s' was terminated before it could finish (after scanning).", err
            )
            self._terminate()
            raise
        self.progress.__exit__(exc_type, exc_value, traceback)
        return None

    def request_scan(self, quantum_id: uuid.UUID) -> None:
        """Send a request to the scanners to scan the given quantum.

        Parameters
        ----------
        quantum_id : `uuid.UUID`
            ID of the quantum to scan.
        """
        self._scan_requests.put(_ScanRequest(quantum_id))

    def request_write(self, request: ProvenanceQuantumScanData) -> None:
        """Send a request to the writer to write provenance for the given
        scan.

        Parameters
        ----------
        request : `ProvenanceQuantumScanData`
            Information from scanning a quantum (or knowing you don't have
            to, in the case of blocked quanta).
        """
        assert self._write_requests is not None, "Writer should not be used if writing is disabled."
        self._write_requests.put(request)

    def poll(self) -> Iterator[ScanReport]:
        """Poll for reports from workers while sending scan requests.

        Yields
        ------
        scan_report : `ScanReport`
            A report from a scanner that a quantum was scanned.

        Notes
        -----
        This iterator blocks until the first scan report is received, and
        then it continues until the report queue is empty.
        """
        block = True
        while report := self._get_report(block=block):
            match self._handle_progress_reports(report):
                case ScanReport() as scan_report:
                    # After the first report arrives, stop blocking so the
                    # caller can get back to sending scan requests.
                    block = False
                    yield scan_report
                case None:
                    pass
                case unexpected:
                    raise AssertionError(f"Unexpected message {unexpected!r} to supervisor.")

    @overload
    def _get_report(self, block: Literal[True]) -> Report: ...

    @overload
    def _get_report(self, block: bool) -> Report | None: ...

    def _get_report(self, block: bool) -> Report | None:
        """Get a report from the reports queue, with timeout guards on
        blocking requests.

        Parameters
        ----------
        block : `bool`
            If `True`, keep retrying until a report arrives (checking worker
            liveness after each timeout).  If `False`, return `None` when the
            queue is empty.

        Raises
        ------
        _WorkerCommunicationError
            Raised if, while blocking, a worker that has not yet reported
            success is found to be dead.  The exception argument is the
            worker's name.
        """
        report = self._reports.get(block=block, timeout=self.config.worker_check_timeout)
        while report is None and block:
            # We hit the timeout; make sure all of the workers
            # that should be alive actually are.
            for name, worker in self.workers.items():
                if not worker.successful and not worker.is_alive():
                    # A worker died without reporting success; bail out and
                    # let __exit__ perform a disorderly shutdown.
                    raise _WorkerCommunicationError(name)
            # If nothing is dead and we didn't hit the hang timeout, keep
            # trying.
            report = self._reports.get(block=block, timeout=self.config.worker_check_timeout)
        return report

    def _handle_progress_reports(
        self, report: Report, already_failing: bool = False
    ) -> ScanReport | _WorkerDone | None:
        """Handle reports to the supervisor that can appear at any time, and
        are typically just updates to the progress we've made.

        This includes:

        - exceptions from workers (which raise `FatalWorkerError` here to
          trigger ``__exit__``);
        - ingest reports;
        - write reports;
        - progress logs.

        If one of these is handled, `None` is returned; otherwise the
        original report is returned.
        """
        match report:
            case _WorkerErrorMessage(traceback=traceback, worker=worker):
                self.progress.log.fatal("Exception raised on %s: \n%s", worker, traceback)
                if not already_failing:
                    raise FatalWorkerError()
            case _IngestReport(n_producers=n_producers):
                self.progress.quantum_ingests.update(n_producers)
            case _Sentinel.WRITE_REPORT:
                self.progress.writes.update(1)
            case _ProgressLog(message=message, level=level):
                self.progress.log.log(level, "%s [after %0.1fs]", message, self.progress.elapsed_time)
            case _:
                return report
        return None

481 

482 

class WorkerCommunicator:
    """A base class for non-supervisor worker communicators.

    Parameters
    ----------
    supervisor : `SupervisorCommunicator`
        Communicator for the supervisor to grab queues and information from.
    name : `str`
        Human-readable name for this worker.

    Notes
    -----
    Each worker communicator is constructed in the main process and entered as
    a context manager *only* on the actual worker process, so attributes that
    cannot be pickled are constructed in ``__enter__`` instead of ``__init__``.

    Worker communicators provide access to an `AggregatorConfig` and a logger
    to their workers.  As context managers, they handle exceptions and ensure
    clean shutdowns, and since most workers need to use a lot of other context
    managers (for file reading and writing, mostly), they provide an `enter`
    method to keep every worker from also having to be a context manager just
    to hold a context manager instance attribute.

    Worker communicators can also be configured to record and dump profiling
    information.
    """

    def __init__(self, supervisor: SupervisorCommunicator, name: str):
        self.name = name
        self.config = supervisor.config
        self._reports = supervisor._reports
        self._cancel_event = supervisor._cancel_event

    def __enter__(self) -> Self:
        _disable_resources_parallelism()
        self.log = make_worker_log(self.name, self.config)
        self.log.verbose("%s has PID %s (parent is %s).", self.name, os.getpid(), os.getppid())
        self._exit_stack = ExitStack().__enter__()
        if self.config.n_processes > 1:
            # Multiprocessing: ignore interrupts so we can shut down cleanly.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
        if self.config.worker_profile_dir is not None:
            # We use time.time because we're interested in wall-clock time,
            # not just CPU effort, since this is I/O-bound work.
            self._profiler = cProfile.Profile(timer=time.time)
            self._profiler.enable()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        # This guard must match the one in __enter__ that enabled the
        # profiler; it previously also required n_processes > 1, which left
        # thread-based workers profiling without ever dumping their stats.
        if self.config.worker_profile_dir is not None:
            self._profiler.disable()
            os.makedirs(self.config.worker_profile_dir, exist_ok=True)
            self._profiler.dump_stats(os.path.join(self.config.worker_profile_dir, f"{self.name}.profile"))
        if exc_value is not None:
            assert exc_type is not None, "Should be guaranteed by Python, but MyPy doesn't know that."
            if exc_type is not FatalWorkerError:
                # An error originated on this worker; format the traceback
                # here and ship it to the supervisor as text (it cannot be
                # re-raised there, only logged).
                self.log.warning("Error raised on this worker.", exc_info=(exc_type, exc_value, traceback))
                assert exc_type is not None and traceback is not None
                self._reports.put(
                    _WorkerErrorMessage(
                        self.name,
                        "".join(format_exception(exc_type, exc_value, traceback)),
                    )
                )
                self.log.debug("Error message sent to supervisor.")
            else:
                self.log.warning("Shutting down due to exception raised on another worker.")
        self._exit_stack.__exit__(exc_type, exc_value, traceback)
        # Always suppress the exception: it has either been reported to the
        # supervisor or originated there (FatalWorkerError), so the worker
        # itself should exit cleanly.
        return True

    @property
    def exit_stack(self) -> ExitStack:
        """A `contextlib.ExitStack` tied to the communicator."""
        return self._exit_stack

    def log_progress(self, level: int, message: str) -> None:
        """Send a high-level log message to the supervisor.

        Parameters
        ----------
        level : `int`
            Log level.  Should be ``VERBOSE`` or higher.
        message : `str`
            Log message.
        """
        self._reports.put(_ProgressLog(message=message, level=level))

    def check_for_cancel(self) -> None:
        """Check for a cancel signal from the supervisor and raise
        `FatalWorkerError` if it is present.
        """
        if self._cancel_event.is_set():
            raise FatalWorkerError()

581 

582 

class ScannerCommunicator(WorkerCommunicator):
    """A communicator for scanner workers.

    Parameters
    ----------
    supervisor : `SupervisorCommunicator`
        Communicator for the supervisor to grab queues and information from.
    scanner_id : `int`
        Integer ID for this scanner.
    """

    def __init__(self, supervisor: SupervisorCommunicator, scanner_id: int):
        super().__init__(supervisor, self.get_worker_name(scanner_id))
        self.scanner_id = scanner_id
        self._scan_requests = supervisor._scan_requests
        self._ingest_requests = supervisor._ingest_requests
        self._write_requests = supervisor._write_requests
        self._compression_dict = supervisor._compression_dict
        # Shutdown bookkeeping: whether we have taken our one
        # NO_MORE_SCAN_REQUESTS sentinel, and whether we have sent our
        # NO_MORE_INGEST_REQUESTS sentinel to the ingester.
        self._got_no_more_scan_requests: bool = False
        self._sent_no_more_ingest_requests: bool = False

    @staticmethod
    def get_worker_name(scanner_id: int) -> str:
        """Return the worker name for the scanner with the given ID."""
        return f"scanner-{scanner_id:03d}"

    def report_scan(self, msg: ScanReport) -> None:
        """Report a completed scan to the supervisor.

        Parameters
        ----------
        msg : `ScanReport`
            Report to send.
        """
        self._reports.put(msg)

    def request_ingest(self, request: IngestRequest) -> None:
        """Ask the ingester to ingest a quantum's outputs.

        Parameters
        ----------
        request : `IngestRequest`
            Description of the datasets to ingest.

        Notes
        -----
        If this request has no datasets, this automatically reports the
        ingest as complete to the supervisor instead of sending it to the
        ingester.
        """
        if request:
            self._ingest_requests.put(request)
        else:
            # Nothing to ingest; count the producing quantum as done.
            self._reports.put(_IngestReport(1))

    def request_write(self, request: ProvenanceQuantumScanData) -> None:
        """Ask the writer to write provenance for a quantum.

        Parameters
        ----------
        request : `ProvenanceQuantumScanData`
            Result of scanning a quantum.
        """
        assert self._write_requests is not None, "Writer should not be used if writing is disabled."
        self._write_requests.put(request)

    def get_compression_dict(self) -> bytes | None:
        """Attempt to get the compression dict from the writer.

        Returns
        -------
        data : `bytes` or `None`
            The `bytes` representation of the compression dictionary, or
            `None` if the compression dictionary is not yet available.

        Notes
        -----
        A scanner should only call this method before it actually has the
        compression dict.
        """
        # NOTE(review): assumes Queue.get() with no arguments is
        # non-blocking and returns None when empty — confirm against the
        # Queue implementation in ._workers.
        if (cdict := self._compression_dict.get()) is not None:
            return cdict.data
        return None

    def poll(self) -> Iterator[uuid.UUID]:
        """Poll for scan requests to process.

        Yields
        ------
        quantum_id : `uuid.UUID`
            ID of a new quantum to scan.

        Notes
        -----
        This iterator ends when the supervisor reports that it is done
        traversing the graph.
        """
        while True:
            self.check_for_cancel()
            scan_request = self._scan_requests.get(block=True, timeout=self.config.worker_sleep)
            if scan_request is _Sentinel.NO_MORE_SCAN_REQUESTS:
                # Take exactly one sentinel; remember it for __exit__.
                self._got_no_more_scan_requests = True
                return
            if scan_request is not None:
                yield scan_request.quantum_id

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        result = super().__exit__(exc_type, exc_value, traceback)
        # Tell downstream workers we will send no more requests.
        self._ingest_requests.put(_Sentinel.NO_MORE_INGEST_REQUESTS)
        if self._write_requests is not None:
            self._write_requests.put(_Sentinel.NO_MORE_WRITE_REQUESTS)
        # If we exited before poll() consumed our shutdown sentinel (e.g. on
        # an exception), drain the scan-request queue until we take it.
        while not self._got_no_more_scan_requests:
            if (
                not self._got_no_more_scan_requests
                and self._scan_requests.get(block=True) is _Sentinel.NO_MORE_SCAN_REQUESTS
            ):
                self._got_no_more_scan_requests = True
        # We let the writer clear out the compression dict queue.
        self.log.verbose("Sending completion message.")
        self._reports.put(_WorkerDone(self.name))
        return result

707 

708 

class IngesterCommunicator(WorkerCommunicator):
    """A communicator for the ingester worker.

    Parameters
    ----------
    supervisor : `SupervisorCommunicator`
        Communicator for the supervisor to grab queues and information from.
    """

    def __init__(self, supervisor: SupervisorCommunicator):
        super().__init__(supervisor, self.get_worker_name())
        self.n_scanners = supervisor.n_scanners
        self._ingest_requests = supervisor._ingest_requests
        # Number of scanners that have sent their NO_MORE_INGEST_REQUESTS
        # sentinel; we can shut down once all n_scanners have.
        self._n_requesters_done = 0

    @staticmethod
    def get_worker_name() -> str:
        """Return the name used for the ingester worker."""
        return "ingester"

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        result = super().__exit__(exc_type, exc_value, traceback)
        # Drain our request queue until every scanner's shutdown sentinel
        # has arrived, so nothing is left behind at shutdown.
        while self._n_requesters_done != self.n_scanners:
            self.log.debug(
                "Waiting for %d requesters to be done (currently %d).",
                self.n_scanners,
                self._n_requesters_done,
            )
            if self._ingest_requests.get(block=True) is _Sentinel.NO_MORE_INGEST_REQUESTS:
                self._n_requesters_done += 1
        self.log.verbose("Sending completion message.")
        self._reports.put(_WorkerDone(self.name))
        return result

    def report_ingest(self, n_producers: int) -> None:
        """Report to the supervisor that an ingest batch was completed.

        Parameters
        ----------
        n_producers : `int`
            Number of producing quanta whose datasets were ingested.
        """
        self._reports.put(_IngestReport(n_producers))

    def poll(self) -> Iterator[IngestRequest]:
        """Poll for ingest requests from the scanner workers.

        Yields
        ------
        request : `IngestRequest`
            A request to ingest datasets produced by a single quantum.

        Notes
        -----
        This iterator ends when all scanners indicate that they are done
        making ingest requests.
        """
        while True:
            self.check_for_cancel()
            ingest_request = self._ingest_requests.get(block=True, timeout=_TINY_TIMEOUT)
            if ingest_request is _Sentinel.NO_MORE_INGEST_REQUESTS:
                self._n_requesters_done += 1
                if self._n_requesters_done == self.n_scanners:
                    return
                else:
                    continue
            if ingest_request is not None:
                yield ingest_request

781 

782 

class WriterCommunicator(WorkerCommunicator):
    """A communicator for the writer worker.

    Parameters
    ----------
    supervisor : `SupervisorCommunicator`
        Communicator for the supervisor to grab queues and information from.
    """

    def __init__(self, supervisor: SupervisorCommunicator):
        # A writer worker should only exist when provenance writing is on.
        assert supervisor._write_requests is not None
        super().__init__(supervisor, self.get_worker_name())
        self.n_scanners = supervisor.n_scanners
        self._write_requests = supervisor._write_requests
        self._compression_dict = supervisor._compression_dict
        # Both the scanners and the supervisor send write requests, so we
        # expect one shutdown sentinel from each of them.
        self._n_requesters = supervisor.n_scanners + 1
        self._n_requesters_done = 0
        self._sent_compression_dict = False

    @staticmethod
    def get_worker_name() -> str:
        """Return the name used for the writer worker."""
        return "writer"

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        result = super().__exit__(exc_type, exc_value, traceback)
        if exc_type is None:
            self.log_progress(logging.INFO, "Provenance quantum graph written successfully.")
        # Drain our request queue until every requester's shutdown sentinel
        # has arrived.
        while self._n_requesters_done != self._n_requesters:
            self.log.debug(
                "Waiting for %d requesters to be done (currently %d).",
                self._n_requesters,
                self._n_requesters_done,
            )
            if self._write_requests.get(block=True) is _Sentinel.NO_MORE_WRITE_REQUESTS:
                self._n_requesters_done += 1
        # The writer clears out the compression dictionary queue, since
        # scanners may not have taken every copy we put there.
        if self._compression_dict.clear():
            self.log.verbose("Cleared out compression dictionary queue.")
        else:
            self.log.verbose("Compression dictionary queue was already empty.")
        self.log.verbose("Sending completion message.")
        self._reports.put(_WorkerDone(self.name))
        return result

    def poll(self) -> Iterator[ProvenanceQuantumScanData]:
        """Poll for writer requests from the scanner workers and supervisor.

        Yields
        ------
        request : `ProvenanceQuantumScanData`
            The result of a quantum scan.

        Notes
        -----
        This iterator ends when all scanners and the supervisor indicate
        that they are done making write requests.
        """
        while True:
            self.check_for_cancel()
            write_request = self._write_requests.get(block=True, timeout=_TINY_TIMEOUT)
            if write_request is _Sentinel.NO_MORE_WRITE_REQUESTS:
                self._n_requesters_done += 1
                if self._n_requesters_done == self._n_requesters:
                    return
                else:
                    continue
            if write_request is not None:
                yield write_request

    def send_compression_dict(self, cdict_data: bytes) -> None:
        """Send the compression dictionary to the scanners.

        Parameters
        ----------
        cdict_data : `bytes`
            The `bytes` representation of the compression dictionary.
        """
        self.log.debug("Sending compression dictionary.")
        # One copy per scanner; each scanner takes exactly one.
        for _ in range(self.n_scanners):
            self._compression_dict.put(_CompressionDictionary(cdict_data))
        self._sent_compression_dict = True

    def report_write(self) -> None:
        """Report to the supervisor that provenance for a quantum was written
        to the graph.
        """
        self._reports.put(_Sentinel.WRITE_REPORT)

    def periodically_check_for_cancel[T](self, iterable: Iterable[T], n: int = 100) -> Iterator[T]:
        """Iterate while checking for a cancellation signal every ``n``
        iterations.

        Parameters
        ----------
        iterable : `~collections.abc.Iterable`
            Object to iterate over.
        n : `int`
            Check for cancellation every ``n`` iterations.

        Returns
        -------
        iterator : `~collections.abc.Iterator`
            Iterator.
        """
        i = 0
        for entry in iterable:
            yield entry
            i += 1
            if i % n == 0:
                self.check_for_cancel()