Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 39%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

124 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from uuid import UUID 

24 

25__all__ = ("LoadHelper",) 

26 

27import functools 

28import io 

29import struct 

30from collections import UserDict 

31from dataclasses import dataclass 

32from typing import IO, TYPE_CHECKING, Iterable, Optional, Type, Union 

33 

34from lsst.daf.butler import DimensionUniverse 

35from lsst.resources import ResourcePath 

36from lsst.resources.file import FileResourcePath 

37from lsst.resources.s3 import S3ResourcePath 

38 

39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true

40 from . import QuantumGraph 

41 

42 

class RegistryDict(UserDict):
    """Custom dict that returns the desired default if a key is missing.

    Looking up a key that was never registered returns `DefaultLoadHelper`
    rather than raising `KeyError`. The default is only returned, it is not
    inserted into the mapping.
    """

    def __missing__(self, key):
        return DefaultLoadHelper

47 

48 

# Create a registry to hold all the load helper classes, keyed by the
# "handle type" (ResourcePath class or io.IOBase) each helper reads from
HELPER_REGISTRY = RegistryDict()

51 

52 

def register_helper(URICLass: Union[Type[ResourcePath], Type[IO[bytes]]]):
    """Used to register classes as Load helpers

    When decorating a class the parameter is the class of "handle type", i.e.
    a ResourcePath type or open file handle that will be used to do the
    loading. This is then associated with the decorated class such that when
    the parameter type is used to load data, the appropriate helper to work
    with that data type can be returned.

    A decorator is used so that in theory someone could define another handler
    in a different module and register it for use.

    Parameters
    ----------
    URICLass : Type of `~lsst.resources.ResourcePath` or `typing.IO` of bytes
        Type for which the decorated class should be mapped to.

    Returns
    -------
    wrapper : callable
        Class decorator that stores the decorated class in ``HELPER_REGISTRY``
        under ``URICLass`` and returns the class unchanged.
    """

    def wrapper(class_):
        HELPER_REGISTRY[URICLass] = class_
        return class_

    return wrapper

76 

77 

class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `typing.IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter, forces a user to
        interact manually and take this into account before they can be used in
        production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the graph save version is below the
        minimum specified version.
    RuntimeError
        Raised if the save version of the file is newer than the newest
        version this code knows how to read.
    """

    def __init__(self, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int):
        headerBytes = self.__setup_impl(uriObject, minimumVersion)
        self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)

    def __setup_impl(self, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int) -> bytes:
        """Validate the save file preamble and select a deserializer.

        Stores ``uriObject`` and the selected deserializer on the instance and
        returns the raw header bytes; see the class documentation for the
        exceptions raised.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic}, expected {MAGIC_BYTES}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        self.deserializer = deserializerClass(preambleSize, sizeBytes)

        # get the header info
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(
        cls, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int = 3
    ) -> Optional[str]:
        """Return the unpacked header of a save file without loading a graph

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `typing.IO` of bytes
            Object used to retrieve the raw bytes of the save.
        minimumVersion : `int`, optional
            Minimum version of a save file to read, defaults to 3.

        Returns
        -------
        header : `str` or `None`
            Whatever the deserializer's ``unpackHeader`` produces for the
            header bytes of this save file.
        """
        # bypass __init__, which would additionally call readHeaderInfo
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # always give the helper a chance to release what it opened, even
            # when validation or unpacking fails
            instance.close()

    def load(
        self,
        universe: DimensionUniverse,
        nodes: Optional[Iterable[Union[str, UUID]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes : `Iterable` of `str` or `~uuid.UUID`, or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). Strings are converted to `~uuid.UUID`.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = list(self.headerInfo.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodes = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodes - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        return self.deserializer.constructGraph(nodes, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.
        """
        # the whole payload is read once and cached on the instance
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self):
        """Cleans up an instance if needed. Base class does nothing"""
        pass

265 

266 

@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper that implements partial loading of a graph using a s3 uri"""

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Fetch bytes ``start`` up to (not including) ``stop`` from S3

        Raises
        ------
        FileNotFoundError
            Raised if the client reports that the key or bucket does not
            exist.
        """
        client = self.uriObject.client
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop-1}"
        try:
            response = client.get_object(
                Bucket=self.uriObject.netloc, Key=self.uriObject.relativeToPathRoot, Range=byteRange
            )
        except (client.exceptions.NoSuchKey, client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        stream = response["Body"]
        try:
            return stream.read()
        finally:
            # release the response body even if the read fails part way
            stream.close()

287 

288 

@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper that implements partial loading of a graph using a file
    uri
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Read bytes ``start`` up to (not including) ``stop`` from the file

        The file at ``uriObject.ospath`` is opened on the first call and the
        handle is kept on the instance for later calls.
        """
        if not hasattr(self, "fileHandle"):
            self.fileHandle = open(self.uriObject.ospath, "rb")
        self.fileHandle.seek(start)
        return self.fileHandle.read(stop - start)

    def close(self):
        """Close the file handle, if one was ever opened"""
        if hasattr(self, "fileHandle"):
            self.fileHandle.close()

301 

302 

@register_helper(io.IOBase)  # type: ignore
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper that reads from an already open binary file handle"""

    # This handler is special in that it does not get initialized with a
    # ResourcePath, but an open file handle.

    # Most everything stays the same, the variable is even stored as uriObject,
    # because the interface needed for reading is the same. Unfortunately
    # because we do not have Protocols yet, this can not be nicely expressed
    # with typing.

    # This helper does support partial loading

    def __init__(self, uriObject: IO[bytes], minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: IO[bytes]
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # This differs from the default __init__ to force the io object
        # back to the beginning so that in the case the entire file is to
        # read in the file is not already in a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Seek to ``start`` and read ``stop - start`` bytes from the handle"""
        self.uriObject.seek(start)
        result = self.uriObject.read(stop - start)
        return result

328 

329 

@dataclass
class LoadHelper:
    """This is a helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: Union[ResourcePath, IO[bytes]]
    """ResourcePath object or open binary file handle from which the
    `QuantumGraph` is to be loaded
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __enter__(self) -> DefaultLoadHelper:
        # instantiate the loader matching the type of ``uri`` and hand it to
        # the body of the ``with`` statement
        self._loaded = self._determineLoader()(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # exceptions are not suppressed, the loader is only cleaned up
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Look up the load helper class registered for the type of ``uri``"""
        # Only one handler is registered for anything that is an instance of
        # IOBase, so if any type is a subtype of that, set the key explicitly
        # so the correct loader is found, otherwise index by the type
        if isinstance(self.uri, io.IOBase):
            key = io.IOBase
        else:
            key = type(self.uri)
        return HELPER_REGISTRY[key]

    def readHeader(self) -> Optional[str]:
        """Return the header of the save file via the loader's ``dumpHeader``"""
        type_ = self._determineLoader()
        return type_.dumpHeader(self.uri, self.minimumVersion)