Coverage for python / lsst / pipe / base / quantum_graph / aggregator / _writer.py: 24%

84 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 08:32 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("Writer",) 

31 

32import dataclasses 

33 

34import zstandard 

35 

36from ...log_on_close import LogOnClose 

37from ...pipeline_graph import TaskImportMode 

38from .._predicted import PredictedQuantumGraphComponents, PredictedQuantumGraphReader 

39from .._provenance import ProvenanceQuantumGraphWriter, ProvenanceQuantumScanData 

40from ._communicators import WriterCommunicator 

41 

42 

@dataclasses.dataclass
class Writer:
    """A helper class for the provenance aggregator that actually writes the
    provenance quantum graph file.

    Notes
    -----
    Constructing a `Writer` reads the predicted quantum graph from
    ``predicted_path``.  The usual entry point is `run`, which builds the
    writer inside the communicator's context and then calls `loop`.
    """

    predicted_path: str
    """Path to the predicted quantum graph."""

    comms: WriterCommunicator
    """Communicator object for this worker."""

    predicted: PredictedQuantumGraphComponents = dataclasses.field(init=False)
    """Components of the predicted quantum graph."""

    pending_compression_training: list[ProvenanceQuantumScanData] = dataclasses.field(default_factory=list)
    """Unprocessed quantum scans that are being accumulated in order to
    build a compression dictionary.
    """

    def __post_init__(self) -> None:
        """Read the init quanta and quantum datasets of the predicted graph,
        checking for cancellation before each read step.
        """
        assert self.comms.config.is_writing_provenance, "Writer should not be used if writing is disabled."
        self.comms.log.info("Reading predicted quantum graph.")
        with PredictedQuantumGraphReader.open(
            self.predicted_path, import_mode=TaskImportMode.DO_NOT_IMPORT
        ) as reader:
            self.comms.check_for_cancel()
            reader.read_init_quanta()
            self.comms.check_for_cancel()
            reader.read_quantum_datasets()
            self.predicted = reader.components

    @staticmethod
    def run(predicted_path: str, comms: WriterCommunicator) -> None:
        """Run the writer.

        Parameters
        ----------
        predicted_path : `str`
            Path to the predicted quantum graph.
        comms : `WriterCommunicator`
            Communicator for the writer.

        Notes
        -----
        This method is designed to run as the ``target`` in
        `WorkerFactory.make_worker`.
        """
        with comms:
            writer = Writer(predicted_path, comms)
            writer.loop()

    def loop(self) -> None:
        """Run the main loop for the writer.

        Notes
        -----
        When dictionary training is disabled (``zstd_dict_size`` is falsy)
        the low-level writer is opened immediately and every request is
        written as it arrives.  Otherwise requests are queued until
        ``zstd_dict_n_inputs`` of them are available, at which point
        `make_qg_writer` trains the dictionary and writes the queue.
        """
        qg_writer: ProvenanceQuantumGraphWriter | None = None
        if not self.comms.config.zstd_dict_size:
            qg_writer = self.make_qg_writer()
        self.comms.log.info("Polling for write requests from scanners.")
        for request in self.comms.poll():
            if qg_writer is None:
                # Still collecting training samples; `make_qg_writer` writes
                # (and reports) the queued requests once enough are present.
                self.pending_compression_training.append(request)
                if len(self.pending_compression_training) >= self.comms.config.zstd_dict_n_inputs:
                    qg_writer = self.make_qg_writer()
            else:
                qg_writer.write_scan_data(request)
                self.comms.report_write()
        if qg_writer is None:
            # Polling ended before enough requests arrived to train a
            # dictionary; open the writer anyway so the queue gets flushed.
            qg_writer = self.make_qg_writer()
        self.comms.log.info("Writing init outputs.")
        qg_writer.write_init_outputs(assume_existence=False)

    def make_qg_writer(self) -> ProvenanceQuantumGraphWriter:
        """Make a compression dictionary, open the low-level writers, and
        write any accumulated scans that were needed to make the compression
        dictionary.

        Returns
        -------
        qg_writer : `ProvenanceQuantumGraphWriter`
            Low-level writers struct.

        Notes
        -----
        Besides constructing the writer, this sends the compression
        dictionary through the communicator, writes and reports every queued
        scan request, empties the queue, and writes the overall inputs and
        packages.
        """
        cdict = self.make_compression_dictionary()
        # Serialize the dictionary once; the same bytes are sent through the
        # communicator and handed to the low-level writer.
        cdict_data = cdict.as_bytes()
        self.comms.send_compression_dict(cdict_data)
        assert self.comms.config.is_writing_provenance and self.comms.config.output_path is not None
        self.comms.log.info("Opening output files and processing predicted graph.")
        qg_writer = ProvenanceQuantumGraphWriter(
            self.comms.config.output_path,
            exit_stack=self.comms.exit_stack,
            log_on_close=LogOnClose(self.comms.log_progress),
            predicted=self.predicted,
            zstd_level=self.comms.config.zstd_level,
            cdict_data=cdict_data,
            loop_wrapper=self.comms.periodically_check_for_cancel,
            log=self.comms.log,
        )
        self.comms.check_for_cancel()
        self.comms.log.info("Compressing and writing queued scan requests.")
        for request in self.pending_compression_training:
            qg_writer.write_scan_data(request)
            self.comms.report_write()
        # Drop the queued scans now that they have been handed to the writer.
        # Rebind to an empty list instead of ``del``-ing the attribute:
        # removing a dataclass field makes the generated ``__repr__`` and
        # ``__eq__`` raise `AttributeError` for the rest of the instance's
        # life.
        self.pending_compression_training = []
        self.comms.check_for_cancel()
        self.comms.log.info("Writing overall inputs.")
        qg_writer.write_overall_inputs(self.comms.periodically_check_for_cancel)
        qg_writer.write_packages()
        self.comms.log.info("Returning to write request loop.")
        return qg_writer

    def make_compression_dictionary(self) -> zstandard.ZstdCompressionDict:
        """Make the compression dictionary.

        Returns
        -------
        cdict : `zstandard.ZstdCompressionDict`
            The compression dictionary.

        Notes
        -----
        An empty dictionary is returned when training is disabled
        (``zstd_dict_size`` is falsy) or when fewer than
        ``zstd_dict_n_inputs`` scan requests have been queued.
        """
        if (
            not self.comms.config.zstd_dict_size
            or len(self.pending_compression_training) < self.comms.config.zstd_dict_n_inputs
        ):
            self.comms.log.info("Making compressor with no dictionary.")
            return zstandard.ZstdCompressionDict(b"")
        self.comms.log.info("Training compression dictionary.")
        training_inputs: list[bytes] = []
        # We start the dictionary training with *predicted* quantum dataset
        # models, since those have almost all of the same attributes as the
        # provenance quantum and dataset models, and we can get a nice random
        # sample from just the first N, since they're ordered by UUID.  We
        # chop out the datastore records since those don't appear in the
        # provenance graph.
        for predicted_quantum in self.predicted.quantum_datasets.values():
            if len(training_inputs) == self.comms.config.zstd_dict_n_inputs:
                break
            # NOTE: this clears the records on ``self.predicted`` in place,
            # and that same object is later passed to the low-level writer.
            predicted_quantum.datastore_records.clear()
            training_inputs.append(predicted_quantum.model_dump_json().encode())
        # Add the provenance quanta, metadata, and logs we've accumulated.
        for write_request in self.pending_compression_training:
            assert not write_request.is_compressed, "We can't compress without the compression dictionary."
            training_inputs.append(write_request.quantum)
            training_inputs.append(write_request.metadata)
            training_inputs.append(write_request.logs)
        return zstandard.train_dictionary(self.comms.config.zstd_dict_size, training_inputs)

182 training_inputs.append(write_request.metadata) 

183 training_inputs.append(write_request.logs) 

184 return zstandard.train_dictionary(self.comms.config.zstd_dict_size, training_inputs)