Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 26%

82 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-12 02:03 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("LoadHelper",) 

24 

25import struct 

26from contextlib import ExitStack 

27from dataclasses import dataclass 

28from io import BufferedRandom, BytesIO 

29from types import TracebackType 

30from typing import TYPE_CHECKING, BinaryIO, ContextManager, Iterable, Optional, Set, Type, Union 

31from uuid import UUID 

32 

33from lsst.daf.butler import DimensionUniverse 

34from lsst.resources import ResourceHandleProtocol, ResourcePath 

35 

36if TYPE_CHECKING: 36 ↛ 37line 36 didn't jump to line 37, because the condition on line 36 was never true

37 from ._versionDeserializers import DeserializerBase 

38 from .graph import QuantumGraph 

39 

40 

@dataclass
class LoadHelper(ContextManager["LoadHelper"]):
    """This is a helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.

    The helper must be used as a context manager: entering the context opens
    the resource (when needed) and reads the file preamble, after which
    `load` may be called.
    """

    uri: Union[ResourcePath, BinaryIO]
    """ResourcePath object from which the `QuantumGraph` is to be loaded, or
    an already opened binary file-like object which is read from directly.
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __post_init__(self) -> None:
        """Set up the state that only exists while inside a context block."""
        # The handle is populated by __enter__ and cleared again on exit.
        self._resourceHandle: Optional[ResourceHandleProtocol] = None
        # Owns any handle this helper opens itself so that it gets closed.
        self._exitStack = ExitStack()

    def _initialize(self) -> None:
        """Read and validate the file preamble and select the deserializer.

        On success ``self.deserializer`` and ``self.headerBytesRange`` are
        set for later use by `load` and `readHeader`.

        Raises
        ------
        ValueError
            Raised by `_validateSave` if the preamble is not acceptable.
        RuntimeError
            Raised by `_readBytes` if no resource handle is open.
        """
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        save_version = self._validateSave(magic, versionBytes)

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)
        # get the header byte range for later reading
        self.headerBytesRange = (preambleSize + deserializerClass.structSize, self.deserializer.headerSize)

    def _validateSave(self, magic: bytes, versionBytes: bytes) -> int:
        """Implement validation on input file, prior to attempting to load it

        Parameters
        ----------
        magic : `bytes`
            The first few bytes of the file, used to verify it is a
            QuantumGraph save file
        versionBytes : `bytes`
            The next few bytes from the beginning of the file, used to parse
            which version of the QuantumGraph file the save corresponds to

        Returns
        -------
        save_version : `int`
            The save version parsed from the supplied bytes

        Raises
        ------
        ValueError
            Raised if the specified file contains the wrong file signature and
            is not a `QuantumGraph` save, if the graph save version is
            below the minimum specified version, or if it is newer than the
            newest version this code knows how to read.
        """
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < self.minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {self.minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise ValueError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )
        return save_version

    def load(
        self,
        universe: Optional[DimensionUniverse] = None,
        nodes: Optional[Iterable[Union[UUID, str]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe : `~lsst.daf.butler.DimensionUniverse` or None
            DimensionUniverse instance handed to the deserializer when
            constructing the graph. If `None` (the default) the universe
            saved in the graph header is used instead.
        nodes : `Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). Strings are converted to `UUID`.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        RuntimeError
            Raised if the method was not called from within a context block.
            The deserializer may also raise this if the supplied
            DimensionUniverse is not compatible with the one saved in the
            graph (not verifiable from this class alone).
        """
        if self._resourceHandle is None:
            raise RuntimeError("Load can only be used within a context manager")

        headerInfo = self.deserializer.readHeaderInfo(self._readBytes(*self.headerBytesRange))
        # verify this is the expected graph
        if graphID is not None and headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: Set[UUID]
        if nodes is None:
            nodeSet = set(headerInfo.map.keys())
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
        if universe is None:
            universe = headerInfo.universe
        return self.deserializer.constructGraph(nodeSet, self._readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Load the specified byte range from the open resource handle

        Parameters
        ----------
        start : `int`
            The beginning byte location to read
        stop : `int`
            The end byte location to read (exclusive); ``stop - start`` bytes
            are requested from the handle.

        Returns
        -------
        result : `bytes`
            The byte range specified from the `ResourceHandle`

        Raises
        ------
        RuntimeError
            Raise if the method was not called from within a context block
        """
        if self._resourceHandle is None:
            raise RuntimeError("_readBytes must be called from within a context block")
        self._resourceHandle.seek(start)
        return self._resourceHandle.read(stop - start)

    def __enter__(self) -> "LoadHelper":
        """Open the resource if needed and read the file preamble.

        File-like objects are used directly; anything else is opened with
        ``uri.open("rb")`` and registered for closing on exit.
        """
        if isinstance(self.uri, (BinaryIO, BytesIO, BufferedRandom)):
            self._resourceHandle = self.uri
        else:
            self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb"))
        try:
            self._initialize()
        except BaseException:
            # __exit__ is never invoked when __enter__ raises, so release
            # whatever was opened above before propagating the error;
            # otherwise an invalid save file would leak an open handle.
            self._exitStack.close()
            self._resourceHandle = None
            raise
        return self

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Close anything opened by `__enter__` and forget the handle.

        Exceptions raised inside the context block are not suppressed.
        """
        assert self._resourceHandle is not None
        self._exitStack.close()
        self._resourceHandle = None

    def readHeader(self) -> Optional[str]:
        """Return the header of the save file as unpacked by the deserializer.

        The helper enters and leaves its own context for the duration of
        the read, so this must be called from outside a ``with`` block.

        Returns
        -------
        header : `str` or `None`
            Whatever the selected deserializer's ``unpackHeader`` returns for
            the header byte range.
        """
        with self as handle:
            result = handle.deserializer.unpackHeader(self._readBytes(*self.headerBytesRange))
        return result

268 return result