Coverage for python / lsst / pipe / base / quantum_graph / aggregator / _writer.py: 24%
84 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:20 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:20 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("Writer",)
32import dataclasses
34import zstandard
36from ...log_on_close import LogOnClose
37from ...pipeline_graph import TaskImportMode
38from .._predicted import PredictedQuantumGraphComponents, PredictedQuantumGraphReader
39from .._provenance import ProvenanceQuantumGraphWriter, ProvenanceQuantumScanData
40from ._communicators import WriterCommunicator
@dataclasses.dataclass
class Writer:
    """A helper class for the provenance aggregator that actually writes the
    provenance quantum graph file.

    The writer consumes scan requests from scanner workers via its
    communicator.  When zstd dictionary compression is enabled in the
    configuration, it first accumulates scans to use as training samples,
    trains a dictionary, and only then opens the low-level output writers.
    """

    predicted_path: str
    """Path to the predicted quantum graph."""

    comms: WriterCommunicator
    """Communicator object for this worker."""

    predicted: PredictedQuantumGraphComponents = dataclasses.field(init=False)
    """Components of the predicted quantum graph."""

    pending_compression_training: list[ProvenanceQuantumScanData] = dataclasses.field(default_factory=list)
    """Unprocessed quantum scans that are being accumulated in order to
    build a compression dictionary.
    """

    def __post_init__(self) -> None:
        # Eagerly load the predicted graph: everything else the writer does
        # needs it.  Cancellation is checked between the read steps because
        # each read may be slow on large graphs.
        assert self.comms.config.is_writing_provenance, "Writer should not be used if writing is disabled."
        self.comms.log.info("Reading predicted quantum graph.")
        with PredictedQuantumGraphReader.open(
            self.predicted_path, import_mode=TaskImportMode.DO_NOT_IMPORT
        ) as reader:
            self.comms.check_for_cancel()
            reader.read_init_quanta()
            self.comms.check_for_cancel()
            reader.read_quantum_datasets()
            self.predicted = reader.components

    @staticmethod
    def run(predicted_path: str, comms: WriterCommunicator) -> None:
        """Run the writer.

        Parameters
        ----------
        predicted_path : `str`
            Path to the predicted quantum graph.
        comms : `WriterCommunicator`
            Communicator for the writer.

        Notes
        -----
        This method is designed to run as the ``target`` in
        `WorkerFactory.make_worker`.
        """
        # The communicator is used as a context manager so worker shutdown
        # and error reporting happen even if construction or the loop raises.
        with comms:
            writer = Writer(predicted_path, comms)
            writer.loop()

    def loop(self) -> None:
        """Run the main loop for the writer."""
        qg_writer: ProvenanceQuantumGraphWriter | None = None
        if not self.comms.config.zstd_dict_size:
            # Dictionary compression is disabled; open the output writers
            # immediately (make_qg_writer falls back to an empty dictionary).
            qg_writer = self.make_qg_writer()
        self.comms.log.info("Polling for write requests from scanners.")
        for request in self.comms.poll():
            if qg_writer is None:
                # Still gathering training samples for the compression
                # dictionary; buffer the request instead of writing it.
                self.pending_compression_training.append(request)
                if len(self.pending_compression_training) >= self.comms.config.zstd_dict_n_inputs:
                    # Enough samples: train the dictionary, open the writers,
                    # and flush the buffered requests (done inside
                    # make_qg_writer).
                    qg_writer = self.make_qg_writer()
            else:
                qg_writer.write_scan_data(request)
                self.comms.report_write()
        if qg_writer is None:
            # Polling ended before we saw zstd_dict_n_inputs scans; train on
            # whatever was buffered (possibly nothing) and write it out.
            qg_writer = self.make_qg_writer()
        self.comms.log.info("Writing init outputs.")
        qg_writer.write_init_outputs(assume_existence=False)

    def make_qg_writer(self) -> ProvenanceQuantumGraphWriter:
        """Make a compression dictionary, open the low-level writers, and
        write any accumulated scans that were needed to make the compression
        dictionary.

        Returns
        -------
        qg_writer : `ProvenanceQuantumGraphWriter`
            Low-level writers struct.

        Notes
        -----
        This method is intended to be called at most once per `Writer`; it
        consumes (and then deletes) ``pending_compression_training``.
        """
        cdict = self.make_compression_dictionary()
        # Scanners need the dictionary too, so they can compress their own
        # payloads before sending them here.
        self.comms.send_compression_dict(cdict.as_bytes())
        assert self.comms.config.is_writing_provenance and self.comms.config.output_path is not None
        self.comms.log.info("Opening output files and processing predicted graph.")
        qg_writer = ProvenanceQuantumGraphWriter(
            self.comms.config.output_path,
            exit_stack=self.comms.exit_stack,
            log_on_close=LogOnClose(self.comms.log_progress),
            predicted=self.predicted,
            zstd_level=self.comms.config.zstd_level,
            cdict_data=cdict.as_bytes(),
            loop_wrapper=self.comms.periodically_check_for_cancel,
            log=self.comms.log,
        )
        self.comms.check_for_cancel()
        self.comms.log.info("Compressing and writing queued scan requests.")
        # Flush the scans that were buffered while training the dictionary.
        for request in self.pending_compression_training:
            qg_writer.write_scan_data(request)
            self.comms.report_write()
        # Release the training buffer; it is not needed again, and deleting
        # the attribute makes any accidental reuse fail loudly.
        del self.pending_compression_training
        self.comms.check_for_cancel()
        self.comms.log.info("Writing overall inputs.")
        qg_writer.write_overall_inputs(self.comms.periodically_check_for_cancel)
        qg_writer.write_packages()
        self.comms.log.info("Returning to write request loop.")
        return qg_writer

    def make_compression_dictionary(self) -> zstandard.ZstdCompressionDict:
        """Make the compression dictionary.

        Returns
        -------
        cdict : `zstandard.ZstdCompressionDict`
            The compression dictionary.
        """
        if (
            not self.comms.config.zstd_dict_size
            or len(self.pending_compression_training) < self.comms.config.zstd_dict_n_inputs
        ):
            # Either dictionary compression is disabled or we did not see
            # enough scans to train on; an empty dictionary means plain
            # (dictionary-less) zstd compression downstream.
            self.comms.log.info("Making compressor with no dictionary.")
            return zstandard.ZstdCompressionDict(b"")
        self.comms.log.info("Training compression dictionary.")
        training_inputs: list[bytes] = []
        # We start the dictionary training with *predicted* quantum dataset
        # models, since those have almost all of the same attributes as the
        # provenance quantum and dataset models, and we can get a nice random
        # sample from just the first N, since they're ordered by UUID.  We
        # chop out the datastore records since those don't appear in the
        # provenance graph.
        for predicted_quantum in self.predicted.quantum_datasets.values():
            if len(training_inputs) == self.comms.config.zstd_dict_n_inputs:
                break
            # NOTE: this mutates the in-memory predicted model; presumably
            # safe because datastore records are not written to the
            # provenance graph (per the comment above).
            predicted_quantum.datastore_records.clear()
            training_inputs.append(predicted_quantum.model_dump_json().encode())
        # Add the provenance quanta, metadata, and logs we've accumulated.
        for write_request in self.pending_compression_training:
            assert not write_request.is_compressed, "We can't compress without the compression dictionary."
            training_inputs.append(write_request.quantum)
            training_inputs.append(write_request.metadata)
            training_inputs.append(write_request.logs)
        return zstandard.train_dictionary(self.comms.config.zstd_dict_size, training_inputs)