Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 40%

130 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-09 02:55 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("LoadHelper",) 

24 

25import functools 

26import io 

27import struct 

28from dataclasses import dataclass 

29from types import TracebackType 

30from typing import ( 

31 TYPE_CHECKING, 

32 Any, 

33 BinaryIO, 

34 Callable, 

35 ContextManager, 

36 Iterable, 

37 Optional, 

38 Set, 

39 Type, 

40 TypeVar, 

41 Union, 

42) 

43from uuid import UUID 

44 

45from lsst.daf.butler import DimensionUniverse 

46from lsst.resources import ResourcePath 

47from lsst.resources.file import FileResourcePath 

48from lsst.resources.s3 import S3ResourcePath 

49 

50if TYPE_CHECKING: 50 ↛ 51line 50 didn't jump to line 51, because the condition on line 50 was never true

51 from ._versionDeserializers import DeserializerBase 

52 from .graph import QuantumGraph 

53 

54 

55_T = TypeVar("_T") 

56 

57 

# Registry mapping type: a plain dict, except that looking up a handle type
# that was never registered yields the generic fallback helper.
class RegistryDict(dict):
    """Dictionary of load helper classes keyed by handle type.

    Indexing with a key that has not been registered returns
    `DefaultLoadHelper` instead of raising `KeyError`.
    """

    def __missing__(self, key: Any) -> Type[DefaultLoadHelper]:
        # Invoked by ``dict.__getitem__`` on a missing key. The key is not
        # inserted; the fallback class is simply handed back.
        return DefaultLoadHelper


# Module-wide registry holding every registered load helper class
HELPER_REGISTRY = RegistryDict()

66 

67 

def register_helper(URIClass: Union[Type[ResourcePath], Type[BinaryIO]]) -> Callable[[_T], _T]:
    """Build a class decorator that registers a load helper.

    The argument is the "handle type" (a ResourcePath type, or the type used
    for open file handles) that the decorated class knows how to read from.
    The decorated class is stored in ``HELPER_REGISTRY`` under that handle
    type, so that it can later be looked up from the type of the object a
    graph is being loaded from.

    Registration is exposed as a decorator so that, in principle, a helper
    defined in another module can register itself as well.

    Parameters
    ----------
    URIClass : Type of `~lsst.resources.ResourcePath` or `~IO` of bytes
        Handle type that the decorated class should be associated with.

    Returns
    -------
    decorator : `Callable`
        Decorator that records the class it is applied to in the registry
        and returns that class unchanged.
    """

    def _record(helperClass: _T) -> _T:
        # A later registration for the same handle type replaces the
        # earlier one.
        HELPER_REGISTRY[URIClass] = helperClass
        return helperClass

    return _record

91 

92 

class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter, forces a user to
        interact manually and take this into account before they can be used in
        production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the graph save version is below the
        minimum specified version.
    RuntimeError
        Raised if the save version of the file is newer than the newest
        version this code knows how to read.
    """

    def __init__(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except BaseException:
            # Reading the preamble may already have made a subclass acquire a
            # resource (e.g. an open file). The caller never receives the
            # instance when construction fails, so release it here.
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int) -> bytes:
        """Validate the save file preamble and select a deserializer.

        Stores ``uriObject`` and the chosen ``deserializer`` on the instance.

        Returns
        -------
        headerBytes : `bytes`
            The raw bytes of the header section of the save file.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system, so we just tell MyPy to ignore it.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)  # type: ignore

        # get the header info
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int = 3) -> Optional[str]:
        """Return the unpacked header of a save file without loading a graph.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object from which the raw bytes of the save are read.
        minimumVersion : `int`, optional
            Minimum version of a save file to accept, defaults to 3.

        Returns
        -------
        header : `str` or `None`
            Whatever the version specific deserializer's ``unpackHeader``
            produces for the header bytes.
        """
        # __init__ is bypassed on purpose: only the preamble validation and
        # the deserializer are needed, not the parsed header info.
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # Always release resources a subclass may have acquired while
            # reading, even if validation or unpacking failed.
            instance.close()

    def load(
        self,
        universe: DimensionUniverse,
        nodes: Optional[Iterable[Union[UUID, str]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes : `Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default)
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: Set[UUID]
        if nodes is None:
            nodeSet = set(self.headerInfo.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes: Callable[[int, int], bytes] = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        return self.deserializer.constructGraph(nodeSet, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return (slice semantics).

        Returns
        -------
        data : `bytes`
            The bytes in the range ``[start, stop)``.
        """
        # The full content is read once, on first use, and cached on the
        # instance for every later call.
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self) -> None:
        """Cleans up an instance if needed. Base class does nothing"""
        pass

282 

283 

@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph using a s3 uri."""

    uriObject: S3ResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Fetch the byte range ``[start, stop)`` of the object from S3.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The content of the response body for the requested range.

        Raises
        ------
        FileNotFoundError
            Raised if the client reports that the key or the bucket does not
            exist.
        """
        client = self.uriObject.client
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop-1}"
        try:
            response = client.get_object(
                Bucket=self.uriObject.netloc,
                Key=self.uriObject.relativeToPathRoot,
                Range=byteRange,
            )
        except (
            client.exceptions.NoSuchKey,
            client.exceptions.NoSuchBucket,
        ) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        stream = response["Body"]
        try:
            return stream.read()
        finally:
            # Close the response body even when the read itself fails, so the
            # stream is not left open.
            stream.close()

306 

307 

@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph using a file uri.

    The file is opened on the first read and the handle is kept on the
    instance until `close` is called.
    """

    uriObject: FileResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` read from the local file."""
        try:
            handle = self.fileHandle
        except AttributeError:
            # First read on this instance: open the file and remember the
            # handle for subsequent reads.
            handle = self.fileHandle = open(self.uriObject.ospath, "rb")
        handle.seek(start)
        return handle.read(stop - start)

    def close(self) -> None:
        """Close the file handle, if one was ever opened."""
        if not hasattr(self, "fileHandle"):
            return
        self.fileHandle.close()

322 

323 

@register_helper(BinaryIO)
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper that reads from an already open binary file handle.

    This handler is special in that it is initialized with an open file
    handle rather than a ResourcePath. Most everything else stays the same;
    the handle is even stored as ``uriObject``, because the interface needed
    for reading is the same. Unfortunately, because we do not have Protocols
    yet, this can not be nicely expressed with typing.

    This helper does support partial loading.
    """

    def __init__(self, uriObject: BinaryIO, minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: BinaryIO
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # Unlike the default __init__, rewind the io object afterwards so
        # that, should the entire file be read later, the handle is not
        # left in a partially read state by the header reads.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` by seeking the handle."""
        stream = self.uriObject
        stream.seek(start)
        return stream.read(stop - start)

349 

350 

@dataclass
class LoadHelper(ContextManager[DefaultLoadHelper]):
    """Context manager that picks the load helper matching ``uri`` and
    closes it again when the context exits.

    Notes
    -----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: Union[ResourcePath, BinaryIO]
    """ResourcePath object from which the `QuantumGraph` is to be loaded
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __enter__(self) -> DefaultLoadHelper:
        # Instantiate the helper registered for this kind of handle and keep
        # a reference so that __exit__ can close it.
        helperClass = self._determineLoader()
        self._loaded = helperClass(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Returns None, so an exception raised inside the context propagates.
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Look up the helper class registered for the type of ``uri``."""
        # Only one handler is registered for anything that is an instance of
        # IOBase, so every such object is looked up under the BinaryIO key;
        # everything else is indexed by its exact type.
        #
        # Typing for file-like types is a mess; BinaryIO isn't actually
        # a base class of what open(..., 'rb') returns, and MyPy claims
        # that IOBase and BinaryIO actually have incompatible method
        # signatures. IOBase *is* a base class of what open(..., 'rb')
        # returns, so it's what we have to use at runtime.
        if isinstance(self.uri, io.IOBase):  # type: ignore
            return HELPER_REGISTRY[BinaryIO]
        return HELPER_REGISTRY[type(self.uri)]

    def readHeader(self) -> Optional[str]:
        """Return the header of the save file via the matching helper's
        ``dumpHeader``.
        """
        helperClass = self._determineLoader()
        return helperClass.dumpHeader(self.uri, self.minimumVersion)