Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 26%

82 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-12 11:14 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("LoadHelper",) 

24 

25import struct 

26from collections.abc import Iterable 

27from contextlib import AbstractContextManager, ExitStack 

28from dataclasses import dataclass 

29from io import BufferedRandom, BytesIO 

30from types import TracebackType 

31from typing import TYPE_CHECKING, BinaryIO 

32from uuid import UUID 

33 

34from lsst.daf.butler import DimensionUniverse, PersistenceContextVars 

35from lsst.resources import ResourceHandleProtocol, ResourcePath 

36 

37if TYPE_CHECKING: 

38 from ._versionDeserializers import DeserializerBase 

39 from .graph import QuantumGraph 

40 

41 

@dataclass
class LoadHelper(AbstractContextManager["LoadHelper"]):
    """Helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: ResourcePath | BinaryIO
    """ResourcePath object, or an already-open binary stream, from which the
    `QuantumGraph` is to be loaded.
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __post_init__(self) -> None:
        # Handle that all reads go through; it is only set while the helper
        # is being used as a context manager.
        self._resourceHandle: ResourceHandleProtocol | None = None
        # Owns any handle opened by this helper so it is closed on exit.
        self._exitStack = ExitStack()

    def _initialize(self) -> None:
        """Read the file preamble and select the matching deserializer.

        Sets ``self.deserializer`` and ``self.headerBytesRange``; the latter
        is the ``(start, stop)`` pair later handed to `_readBytes` to fetch
        the header.

        Raises
        ------
        ValueError
            Raised by `_validateSave` if the magic bytes or the save version
            are not acceptable.
        RuntimeError
            Raised by `_readBytes` if called outside of a context block.
        """
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        save_version = self._validateSave(magic, versionBytes)

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)
        # get the header byte range for later reading
        self.headerBytesRange = (preambleSize + deserializerClass.structSize, self.deserializer.headerSize)

    def _validateSave(self, magic: bytes, versionBytes: bytes) -> int:
        """Implement validation on input file, prior to attempting to load it.

        Parameters
        ----------
        magic : `bytes`
            The first few bytes of the file, used to verify it is a
            `QuantumGraph` save file.
        versionBytes : `bytes`
            The next few bytes from the beginning of the file, used to parse
            which version of the `QuantumGraph` file the save corresponds to.

        Returns
        -------
        save_version : `int`
            The save version parsed from the supplied bytes.

        Raises
        ------
        ValueError
            Raised if the specified file contains the wrong file signature and
            is not a `QuantumGraph` save, if the graph save version is
            below the minimum specified version, or if it is newer than the
            newest version this code knows how to read.
        """
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < self.minimumVersion:
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space.
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {self.minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise ValueError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )
        return save_version

    def load(
        self,
        universe: DimensionUniverse | None = None,
        nodes: Iterable[UUID | str] | None = None,
        graphID: str | None = None,
    ) -> QuantumGraph:
        """Load in the specified nodes from the graph.

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe : `~lsst.daf.butler.DimensionUniverse` or `None`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
            The universe saved with the graph is used, but if one is passed
            it will be used to validate the compatibility with the loaded
            graph universe.
        nodes : `~collections.abc.Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). Strings are converted to `UUID`.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object.

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        RuntimeError
            Raised if supplied `~lsst.daf.butler.DimensionUniverse` is not
            compatible with the `~lsst.daf.butler.DimensionUniverse` saved in
            the graph. Raised if the method was not called from within a
            context block.
        """
        if self._resourceHandle is None:
            raise RuntimeError("Load can only be used within a context manager")

        headerInfo = self.deserializer.readHeaderInfo(self._readBytes(*self.headerBytesRange))
        # verify this is the expected graph
        if graphID is not None and headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: set[UUID]
        if nodes is None:
            nodeSet = set(headerInfo.map.keys())
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
        if universe is None:
            universe = headerInfo.universe
        # use the daf butler context vars to aid in ensuring deduplication in
        # object instantiation.
        runner = PersistenceContextVars()
        return runner.run(self.deserializer.constructGraph, nodeSet, self._readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Load the specified byte range from the ResourcePath object.

        Parameters
        ----------
        start : `int`
            The beginning byte location to read.
        stop : `int`
            The end byte location to read; ``stop - start`` bytes are
            requested from the handle after seeking to ``start``.

        Returns
        -------
        result : `bytes`
            The byte range specified from the
            `~lsst.resources.ResourceHandleProtocol`.

        Raises
        ------
        RuntimeError
            Raised if the method was not called from within a context block.
        """
        if self._resourceHandle is None:
            raise RuntimeError("_readBytes must be called from within a context block")
        self._resourceHandle.seek(start)
        return self._resourceHandle.read(stop - start)

    def __enter__(self) -> "LoadHelper":
        """Acquire the read handle and parse the file preamble.

        Returns
        -------
        helper : `LoadHelper`
            This object, ready for `load` to be called.
        """
        if isinstance(self.uri, (BinaryIO, BytesIO, BufferedRandom)):
            # Already-open stream supplied by the caller: use it directly. It
            # is not registered with the exit stack, so this helper will not
            # close it.
            self._resourceHandle = self.uri
        else:
            # Open the resource ourselves and let the exit stack close it.
            self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb"))
        self._initialize()
        return self

    def __exit__(
        self,
        type: type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Release any handle opened by `__enter__` and leave the context.

        Returns `None`, so an exception raised inside the block is never
        suppressed.
        """
        assert self._resourceHandle is not None
        self._exitStack.close()
        self._resourceHandle = None

    def readHeader(self) -> str | None:
        """Return the unpacked header of the save file.

        The helper's own context is entered for the duration of the call and
        exited before returning, so the read handle is reset afterwards.

        Returns
        -------
        header : `str` or `None`
            Whatever the selected deserializer's ``unpackHeader`` returns for
            the header bytes.
        """
        with self as handle:
            result = handle.deserializer.unpackHeader(handle._readBytes(*handle.headerBytesRange))
        return result