Coverage for python / lsst / pipe / base / graph / _loadHelpers.py: 26%

87 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-18 08:44 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("LoadHelper",) 

30 

31import struct 

32from collections.abc import Iterable 

33from contextlib import AbstractContextManager, ExitStack 

34from dataclasses import dataclass 

35from io import BufferedRandom, BytesIO 

36from types import TracebackType 

37from typing import TYPE_CHECKING, BinaryIO 

38from uuid import UUID 

39 

40from lsst.daf.butler import DimensionUniverse 

41from lsst.daf.butler.persistence_context import PersistenceContextVars 

42from lsst.resources import ResourceHandleProtocol, ResourcePath 

43 

44if TYPE_CHECKING: 

45 from ._versionDeserializers import DeserializerBase 

46 from .graph import QuantumGraph 

47 

48 

@dataclass
class LoadHelper(AbstractContextManager["LoadHelper"]):
    """Helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: ResourcePath | BinaryIO
    """ResourcePath object from which the `QuantumGraph` is to be loaded
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """
    fullRead: bool = False
    """
    If `True` the file is fetched locally in its entirety before reading;
    otherwise bytes are read on demand through the opened resource handle.
    """

    def __post_init__(self) -> None:
        # Handle to the open resource while inside the context manager;
        # ``None`` signals that the context is not currently active.
        self._resourceHandle: ResourceHandleProtocol | None = None
        # Manages any nested contexts (local copy, open handle) entered in
        # ``__enter__`` so they are all released together in ``__exit__``.
        self._exitStack = ExitStack()

    def _initialize(self) -> None:
        """Read the file preamble, validate it, and set up the deserializer
        and header byte range for subsequent reads.
        """
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        save_version = self._validateSave(magic, versionBytes)

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)
        # get the header byte range for later reading
        self.headerBytesRange = (preambleSize + deserializerClass.structSize, self.deserializer.headerSize)

    def _validateSave(self, magic: bytes, versionBytes: bytes) -> int:
        """Implement validation on input file, prior to attempting to load it

        Parameters
        ----------
        magic : `bytes`
            The first few bytes of the file, used to verify it is a
            `QuantumGraph` save file.
        versionBytes : `bytes`
            The next few bytes from the beginning of the file, used to parse
            which version of the `QuantumGraph` file the save corresponds to.

        Returns
        -------
        save_version : `int`
            The save version parsed from the supplied bytes.

        Raises
        ------
        ValueError
            Raised if the specified file contains the wrong file signature and
            is not a `QuantumGraph` save, or if the graph save version is
            below the minimum specified version, or above the newest version
            this software understands.
        """
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < self.minimumVersion:
            # NOTE: trailing space added after "upgrade" — the implicit string
            # concatenation previously produced "upgradeto the newest".
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {self.minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            # NOTE: trailing space added after "version of" — the implicit
            # string concatenation previously produced "ofQuantum Graph".
            raise ValueError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )
        return save_version

    def load(
        self,
        universe: DimensionUniverse | None = None,
        nodes: Iterable[UUID | str] | None = None,
        graphID: str | None = None,
    ) -> QuantumGraph:
        """Load in the specified nodes from the graph.

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe : `~lsst.daf.butler.DimensionUniverse` or None
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
            The universe saved with the graph is used, but if one is passed
            it will be used to validate the compatibility with the loaded
            graph universe.
        nodes : `~collections.abc.Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default).
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object.

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        RuntimeError
            Raised if supplied `~lsst.daf.butler.DimensionUniverse` is not
            compatible with the `~lsst.daf.butler.DimensionUniverse` saved in
            the graph. Raised if the method was not called from within a
            context block.
        """
        if self._resourceHandle is None:
            raise RuntimeError("Load can only be used within a context manager")

        headerInfo = self.deserializer.readHeaderInfo(self._readBytes(*self.headerBytesRange))
        # verify this is the expected graph
        if graphID is not None and headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: set[UUID]
        if nodes is None:
            nodeSet = set(headerInfo.map.keys())
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
        _readBytes = self._readBytes
        if universe is None:
            universe = headerInfo.universe
        # use the daf butler context vars to aid in ensuring deduplication in
        # object instantiation.
        runner = PersistenceContextVars()
        graph = runner.run(self.deserializer.constructGraph, nodeSet, _readBytes, universe)
        return graph

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Load the specified byte range from the ResourcePath object

        Parameters
        ----------
        start : `int`
            The beginning byte location to read
        stop : `int`
            The end byte location to read

        Returns
        -------
        result : `bytes`
            The byte range specified from the
            `~lsst.resources.ResourceHandleProtocol`.

        Raises
        ------
        RuntimeError
            Raise if the method was not called from within a context block.
        """
        if self._resourceHandle is None:
            raise RuntimeError("_readBytes must be called from within a context block")
        self._resourceHandle.seek(start)
        return self._resourceHandle.read(stop - start)

    def __enter__(self) -> LoadHelper:
        # Already-open binary streams are used directly (and are not closed
        # by this object); ResourcePath inputs are opened via the exit stack.
        if isinstance(self.uri, BinaryIO | BytesIO | BufferedRandom):
            self._resourceHandle = self.uri
        elif self.fullRead:
            # Fetch the whole file locally first, then open the local copy.
            local = self._exitStack.enter_context(self.uri.as_local())
            self._resourceHandle = self._exitStack.enter_context(local.open("rb"))
        else:
            self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb"))
        self._initialize()
        return self

    def __exit__(
        self,
        type: type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Release any contexts entered in __enter__ and drop the handle so
        # subsequent load/_readBytes calls fail fast with RuntimeError.
        assert self._resourceHandle is not None
        self._exitStack.close()
        self._resourceHandle = None

    def readHeader(self) -> str | None:
        """Read and return the save file's header information as a string.

        Returns
        -------
        header : `str` or `None`
            The unpacked header, or `None` if the deserializer for this save
            version does not supply one.
        """
        with self as handle:
            result = handle.deserializer.unpackHeader(self._readBytes(*self.headerBytesRange))
            return result
286 return result