Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 40%

130 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-14 16:10 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("LoadHelper",) 

24 

25import functools 

26import io 

27import struct 

28from dataclasses import dataclass 

29from types import TracebackType 

30from typing import ( 

31 TYPE_CHECKING, 

32 Any, 

33 BinaryIO, 

34 Callable, 

35 ContextManager, 

36 Iterable, 

37 Optional, 

38 Set, 

39 Type, 

40 TypeVar, 

41 Union, 

42) 

43from uuid import UUID 

44 

45from lsst.daf.butler import DimensionUniverse 

46from lsst.resources import ResourcePath 

47from lsst.resources.file import FileResourcePath 

48from lsst.resources.s3 import S3ResourcePath 

49 

50if TYPE_CHECKING: 50 ↛ 51line 50 didn't jump to line 51, because the condition on line 50 was never true

51 from ._versionDeserializers import DeserializerBase 

52 from .graph import QuantumGraph 

53 

54 

55_T = TypeVar("_T") 

56 

57 

class RegistryDict(dict):
    """Mapping of handle types to load helper classes with a fallback.

    Looking up a key that has not been registered yields
    `DefaultLoadHelper` instead of raising `KeyError`. The fallback is
    only returned, it is not stored in the mapping.
    """

    def __missing__(self, key: Any) -> Type[DefaultLoadHelper]:
        # Invoked by ``dict.__getitem__`` when ``key`` is absent.
        return DefaultLoadHelper


# Module-level registry that holds every registered load helper class,
# keyed by the handle type the helper knows how to read from.
HELPER_REGISTRY = RegistryDict()

66 

67 

def register_helper(URIClass: Union[Type[ResourcePath], Type[BinaryIO]]) -> Callable[[_T], _T]:
    """Build a class decorator that registers a load helper.

    The argument is the "handle type" (a `~lsst.resources.ResourcePath`
    subclass or an open binary file type) and the decorated class is the
    helper stored for it in ``HELPER_REGISTRY``. Looking that handle type up
    in the registry afterwards yields the decorated class.

    Registration is exposed as a decorator so that helpers defined in other
    modules can be registered as well.

    Parameters
    ----------
    URIClass : Type of `~lsst.resources.ResourcePath` or `~IO` of bytes
        Handle type that the decorated class should be mapped to.

    Returns
    -------
    decorator : `Callable`
        Decorator that records the class it is applied to and returns that
        class unchanged.
    """

    def decorator(helperClass: _T) -> _T:
        # Registering the same handle type again replaces the earlier entry.
        HELPER_REGISTRY[URIClass] = helperClass
        return helperClass

    return decorator

91 

92 

class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter, forces a user to
        interact manually and take this into account before they can be used in
        production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the graph save version is below the
        minimum specified version.
    RuntimeError
        Raised if the save version of the file is newer than the newest
        version this code knows how to read.
    """

    def __init__(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except BaseException:
            # A failed constructor never hands the instance to the caller, so
            # nothing else would release resources that a subclass reader may
            # already have acquired while reading the header.
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int) -> bytes:
        """Validate the save file preamble and select a deserializer.

        Stores ``uriObject`` and the chosen deserializer on the instance.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object the raw bytes of the save are read from.
        minimumVersion : `int`
            Lowest save version that is accepted.

        Returns
        -------
        headerBytes : `bytes`
            Raw bytes of the header section, as read with the byte reader of
            this class.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match or if the save version is
            lower than ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than ``SAVE_VERSION``.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system, so we just tell MyPy to ignore it.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)  # type: ignore

        # get the header info
        # NOTE(review): headerSize is passed as the *stop* offset of the read,
        # so it is presumably an absolute offset into the file rather than a
        # length -- confirm against DeserializerBase.
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int = 3) -> Optional[str]:
        """Return the unpacked header of a save file without loading a graph.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object the raw bytes of the save are read from.
        minimumVersion : `int`, optional
            Lowest save version that is accepted, defaults to 3.

        Returns
        -------
        header : `str` or `None`
            Whatever the deserializer for the save version returns from
            ``unpackHeader``.

        Raises
        ------
        ValueError
            Raised if the file is not a `QuantumGraph` save or if its version
            is lower than ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than this code can read.
        """
        # Bypass __init__ so that the header info is not parsed, only the
        # preamble checks and the raw header read are done.
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # Release reader resources on the error paths as well.
            instance.close()

    def load(
        self,
        universe: Optional[DimensionUniverse] = None,
        nodes: Optional[Iterable[Union[UUID, str]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe: `~lsst.daf.butler.DimensionUniverse` or None
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
            The universe saved with the graph is used, but if one is passed
            it will be used to validate the compatibility with the loaded
            graph universe.
        nodes : `Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default)
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        RuntimeError
            Raise if Supplied DimensionUniverse is not compatible with the
            DimensionUniverse saved in the graph
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: Set[UUID]
        if nodes is None:
            nodeSet = set(self.headerInfo.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes: Callable[[int, int], bytes] = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        # The universe is only forwarded; any compatibility check happens in
        # the deserializer.
        return self.deserializer.constructGraph(nodeSet, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            Slice ``[start:stop]`` of the buffered contents.
        """
        # The full contents are read once, on first use, and kept on the
        # instance for every later call.
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self) -> None:
        """Cleans up an instance if needed. Base class does nothing"""
        pass

288 

289 

@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper that fetches only the requested bytes from an S3 URI.

    Each read issues a ranged ``get_object`` request through the client of
    the `~lsst.resources.s3.S3ResourcePath`, which is what allows partial
    loading of a graph.
    """

    uriObject: S3ResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return bytes ``[start, stop)`` of the S3 object.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The requested bytes; empty if the range is empty.

        Raises
        ------
        FileNotFoundError
            Raised if the client reports that the key or the bucket does not
            exist.
        """
        if stop <= start:
            # An empty range cannot be written as a valid ``Range`` header
            # (it would read e.g. ``bytes=5-4``); mirror slice semantics of
            # the base class reader instead of sending a malformed request.
            return b""
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop-1}"
        client = self.uriObject.client
        try:
            response = client.get_object(
                Bucket=self.uriObject.netloc, Key=self.uriObject.relativeToPathRoot, Range=byteRange
            )
        except (
            client.exceptions.NoSuchKey,
            client.exceptions.NoSuchBucket,
        ) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        stream = response["Body"]
        try:
            return stream.read()
        finally:
            # Close the streaming body even when the read itself fails.
            stream.close()

312 

313 

@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper that reads only the requested bytes from a local file.

    The file at ``uriObject.ospath`` is opened on the first read and the
    same handle is reused for later reads until `close` is called.
    """

    uriObject: FileResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Seek to ``start`` and read ``stop - start`` bytes from the file."""
        try:
            handle = self.fileHandle
        except AttributeError:
            # First read on this instance: open the file and keep the handle.
            handle = self.fileHandle = open(self.uriObject.ospath, "rb")
        handle.seek(start)
        return handle.read(stop - start)

    def close(self) -> None:
        """Close the file handle if a read ever opened one."""
        try:
            handle = self.fileHandle
        except AttributeError:
            # Nothing was ever opened, so there is nothing to release.
            return
        handle.close()

328 

329 

@register_helper(BinaryIO)
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper for an already open binary file handle.

    This handler is special in that it is not initialized with a
    ResourcePath but with an open file handle. Almost everything else stays
    the same; the handle is even stored as ``uriObject`` because the
    interface needed for reading is the same. Since Protocols are not
    available here, this can not be expressed nicely with typing.

    Partial loading is supported: each read seeks to the requested offset
    and reads only the requested number of bytes.
    """

    def __init__(self, uriObject: BinaryIO, minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: BinaryIO
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # Unlike the default __init__, rewind the handle afterwards. Reading
        # the header moved the file position, and a later read of the whole
        # file must not start from a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Seek to ``start`` and read ``stop - start`` bytes from the handle."""
        handle = self.uriObject
        handle.seek(start)
        return handle.read(stop - start)

355 

356 

@dataclass
class LoadHelper(ContextManager[DefaultLoadHelper]):
    """Context manager that picks the load helper matching a handle.

    Entering the context builds the helper registered for the type of
    ``uri`` and returns it; leaving the context closes that helper.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: Union[ResourcePath, BinaryIO]
    """ResourcePath object from which the `QuantumGraph` is to be loaded
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __enter__(self) -> DefaultLoadHelper:
        helperClass = self._determineLoader()
        # Keep a reference so that __exit__ can close the same helper.
        self._loaded = helperClass(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Returns None, so an exception raised inside the block propagates.
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Return the helper class registered for the type of ``uri``."""
        # Only one handler is registered for anything that is an instance of
        # IOBase, so every such object is looked up under the BinaryIO key;
        # anything else is indexed by its exact type.
        #
        # Typing for file-like types is a mess; BinaryIO isn't actually
        # a base class of what open(..., 'rb') returns, and MyPy claims
        # that IOBase and BinaryIO actually have incompatible method
        # signatures. IOBase *is* a base class of what open(..., 'rb')
        # returns, so it's what we have to use at runtime.
        if isinstance(self.uri, io.IOBase):  # type: ignore
            return HELPER_REGISTRY[BinaryIO]
        return HELPER_REGISTRY[type(self.uri)]

    def readHeader(self) -> Optional[str]:
        """Return the header of the save file without entering the context."""
        helperClass = self._determineLoader()
        return helperClass.dumpHeader(self.uri, self.minimumVersion)