Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 26%

82 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-19 10:39 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("LoadHelper",) 

30 

31import struct 

32from collections.abc import Iterable 

33from contextlib import AbstractContextManager, ExitStack 

34from dataclasses import dataclass 

35from io import BufferedRandom, BytesIO 

36from types import TracebackType 

37from typing import TYPE_CHECKING, BinaryIO 

38from uuid import UUID 

39 

40from lsst.daf.butler import DimensionUniverse, PersistenceContextVars 

41from lsst.resources import ResourceHandleProtocol, ResourcePath 

42 

43if TYPE_CHECKING: 

44 from ._versionDeserializers import DeserializerBase 

45 from .graph import QuantumGraph 

46 

47 

48@dataclass 

49class LoadHelper(AbstractContextManager["LoadHelper"]): 

50 """Helper class to assist with selecting the appropriate loader 

51 and managing any contexts that may be needed. 

52 

53 This helper will raise a `ValueError` if the specified file does not appear 

54 to be a valid `QuantumGraph` save file. 

55 """ 

56 

57 uri: ResourcePath | BinaryIO 

58 """ResourcePath object from which the `QuantumGraph` is to be loaded 

59 """ 

60 minimumVersion: int 

61 """ 

62 Minimum version of a save file to load. Set to -1 to load all 

63 versions. Older versions may need to be loaded, and re-saved 

64 to upgrade them to the latest format before they can be used in 

65 production. 

66 """ 

67 

68 def __post_init__(self) -> None: 

69 self._resourceHandle: ResourceHandleProtocol | None = None 

70 self._exitStack = ExitStack() 

71 

72 def _initialize(self) -> None: 

73 # need to import here to avoid cyclic imports 

74 from ._versionDeserializers import DESERIALIZER_MAP 

75 from .graph import MAGIC_BYTES, STRUCT_FMT_BASE 

76 

77 # Read the first few bytes which correspond to the magic identifier 

78 # bytes, and save version 

79 magicSize = len(MAGIC_BYTES) 

80 # read in just the fmt base to determine the save version 

81 fmtSize = struct.calcsize(STRUCT_FMT_BASE) 

82 preambleSize = magicSize + fmtSize 

83 headerBytes = self._readBytes(0, preambleSize) 

84 magic = headerBytes[:magicSize] 

85 versionBytes = headerBytes[magicSize:] 

86 

87 save_version = self._validateSave(magic, versionBytes) 

88 

89 # select the appropriate deserializer for this save version 

90 deserializerClass = DESERIALIZER_MAP[save_version] 

91 

92 # read in the bytes corresponding to the mappings and initialize the 

93 # deserializer. This will be the bytes that describe the following 

94 # byte boundaries of the header info 

95 sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize) 

96 # DeserializerBase subclasses are required to have the same constructor 

97 # signature as the base class itself, but there is no way to express 

98 # this in the type system, so we just tell MyPy to ignore it. 

99 self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes) 

100 # get the header byte range for later reading 

101 self.headerBytesRange = (preambleSize + deserializerClass.structSize, self.deserializer.headerSize) 

102 

103 def _validateSave(self, magic: bytes, versionBytes: bytes) -> int: 

104 """Implement validation on input file, prior to attempting to load it 

105 

106 Paramters 

107 --------- 

108 magic : `bytes` 

109 The first few bytes of the file, used to verify it is a 

110 `QuantumGraph` save file. 

111 versionBytes : `bytes` 

112 The next few bytes from the beginning of the file, used to parse 

113 which version of the `QuantumGraph` file the save corresponds to. 

114 

115 Returns 

116 ------- 

117 save_version : `int` 

118 The save version parsed from the supplied bytes. 

119 

120 Raises 

121 ------ 

122 ValueError 

123 Raised if the specified file contains the wrong file signature and 

124 is not a `QuantumGraph` save, or if the graph save version is 

125 below the minimum specified version. 

126 """ 

127 from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE 

128 

129 if magic != MAGIC_BYTES: 

130 raise ValueError( 

131 "This file does not appear to be a quantum graph save got magic bytes " 

132 f"{magic!r}, expected {MAGIC_BYTES!r}" 

133 ) 

134 

135 # unpack the save version bytes and verify it is a version that this 

136 # code can understand 

137 (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes) 

138 # loads can sometimes trigger upgrades in format to a latest version, 

139 # in which case accessory code might not match the upgraded graph. 

140 # I.E. switching from old node number to UUID. This clause necessitates 

141 # that users specifically interact with older graph versions and verify 

142 # everything happens appropriately. 

143 if save_version < self.minimumVersion: 

144 raise ValueError( 

145 f"The loaded QuantumGraph is version {save_version}, and the minimum " 

146 f"version specified is {self.minimumVersion}. Please re-run this method " 

147 "with a lower minimum version, then re-save the graph to automatically upgrade" 

148 "to the newest version. Older versions may not work correctly with newer code" 

149 ) 

150 

151 if save_version > SAVE_VERSION: 

152 raise ValueError( 

153 f"The version of this save file is {save_version}, but this version of" 

154 f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}" 

155 ) 

156 return save_version 

157 

158 def load( 

159 self, 

160 universe: DimensionUniverse | None = None, 

161 nodes: Iterable[UUID | str] | None = None, 

162 graphID: str | None = None, 

163 ) -> QuantumGraph: 

164 """Load in the specified nodes from the graph. 

165 

166 Load in the `QuantumGraph` containing only the nodes specified in the 

167 ``nodes`` parameter from the graph specified at object creation. If 

168 ``nodes`` is None (the default) the whole graph is loaded. 

169 

170 Parameters 

171 ---------- 

172 universe: `~lsst.daf.butler.DimensionUniverse` or None 

173 DimensionUniverse instance, not used by the method itself but 

174 needed to ensure that registry data structures are initialized. 

175 The universe saved with the graph is used, but if one is passed 

176 it will be used to validate the compatibility with the loaded 

177 graph universe. 

178 nodes : `~collections.abc.Iterable` of `UUID` or `str`; or `None` 

179 The nodes to load from the graph, loads all if value is None 

180 (the default) 

181 graphID : `str` or `None` 

182 If specified this ID is verified against the loaded graph prior to 

183 loading any Nodes. This defaults to None in which case no 

184 validation is done. 

185 

186 Returns 

187 ------- 

188 graph : `QuantumGraph` 

189 The loaded `QuantumGraph` object. 

190 

191 Raises 

192 ------ 

193 ValueError 

194 Raised if one or more of the nodes requested is not in the 

195 `QuantumGraph` or if graphID parameter does not match the graph 

196 being loaded. 

197 RuntimeError 

198 Raised if supplied `~lsst.daf.butler.DimensionUniverse` is not 

199 compatible with the `~lsst.daf.butler.DimensionUniverse` saved in 

200 the graph. Raised if the method was not called from within a 

201 context block. 

202 """ 

203 if self._resourceHandle is None: 

204 raise RuntimeError("Load can only be used within a context manager") 

205 

206 headerInfo = self.deserializer.readHeaderInfo(self._readBytes(*self.headerBytesRange)) 

207 # verify this is the expected graph 

208 if graphID is not None and headerInfo._buildId != graphID: 

209 raise ValueError("graphID does not match that of the graph being loaded") 

210 # Read in specified nodes, or all the nodes 

211 nodeSet: set[UUID] 

212 if nodes is None: 

213 nodeSet = set(headerInfo.map.keys()) 

214 else: 

215 # only some bytes are being read using the reader specialized for 

216 # this class 

217 # create a set to ensure nodes are only loaded once 

218 nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes} 

219 # verify that all nodes requested are in the graph 

220 remainder = nodeSet - headerInfo.map.keys() 

221 if remainder: 

222 raise ValueError( 

223 f"Nodes {remainder} were requested, but could not be found in the input graph" 

224 ) 

225 _readBytes = self._readBytes 

226 if universe is None: 

227 universe = headerInfo.universe 

228 # use the daf butler context vars to aid in ensuring deduplication in 

229 # object instantiation. 

230 runner = PersistenceContextVars() 

231 graph = runner.run(self.deserializer.constructGraph, nodeSet, _readBytes, universe) 

232 return graph 

233 

234 def _readBytes(self, start: int, stop: int) -> bytes: 

235 """Load the specified byte range from the ResourcePath object 

236 

237 Parameters 

238 ---------- 

239 start : `int` 

240 The beginning byte location to read 

241 end : `int` 

242 The end byte location to read 

243 

244 Returns 

245 ------- 

246 result : `bytes` 

247 The byte range specified from the 

248 `~lsst.resources.ResourceHandleProtocol`. 

249 

250 Raises 

251 ------ 

252 RuntimeError 

253 Raise if the method was not called from within a context block. 

254 """ 

255 if self._resourceHandle is None: 

256 raise RuntimeError("_readBytes must be called from within a context block") 

257 self._resourceHandle.seek(start) 

258 return self._resourceHandle.read(stop - start) 

259 

260 def __enter__(self) -> "LoadHelper": 

261 if isinstance(self.uri, BinaryIO | BytesIO | BufferedRandom): 

262 self._resourceHandle = self.uri 

263 else: 

264 self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb")) 

265 self._initialize() 

266 return self 

267 

268 def __exit__( 

269 self, 

270 type: type[BaseException] | None, 

271 value: BaseException | None, 

272 traceback: TracebackType | None, 

273 ) -> None: 

274 assert self._resourceHandle is not None 

275 self._exitStack.close() 

276 self._resourceHandle = None 

277 

278 def readHeader(self) -> str | None: 

279 with self as handle: 

280 result = handle.deserializer.unpackHeader(self._readBytes(*self.headerBytesRange)) 

281 return result