Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 39%

132 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-27 01:58 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("LoadHelper",) 

24 

25import functools 

26import io 

27import struct 

28from dataclasses import dataclass 

29from types import TracebackType 

30from typing import ( 

31 TYPE_CHECKING, 

32 Any, 

33 BinaryIO, 

34 Callable, 

35 ContextManager, 

36 Iterable, 

37 Optional, 

38 Set, 

39 Type, 

40 TypeVar, 

41 Union, 

42) 

43from uuid import UUID 

44 

45from lsst.daf.butler import DimensionUniverse 

46from lsst.resources import ResourcePath 

47from lsst.resources.file import FileResourcePath 

48from lsst.resources.s3 import S3ResourcePath 

49 

50if TYPE_CHECKING: 50 ↛ 51line 50 didn't jump to line 51, because the condition on line 50 was never true

51 from ._versionDeserializers import DeserializerBase 

52 from .graph import QuantumGraph 

53 

54 

55_T = TypeVar("_T") 

56 

57 

58# Create a custom dict that will return the desired default if a key is missing 

class RegistryDict(dict):
    """Dictionary of load-helper classes with a built-in fallback.

    Looking up a key that was never stored does not raise `KeyError`;
    `DefaultLoadHelper` is returned instead. The missing key is not added to
    the dictionary by the lookup.
    """

    def __missing__(self, key: Any) -> Type[DefaultLoadHelper]:
        # Called by ``dict.__getitem__`` only when ``key`` is absent; every
        # unregistered handle type is served by the default helper.
        return DefaultLoadHelper

62 

63 

64# Create a registry to hold all the load Helper classes 

65HELPER_REGISTRY = RegistryDict() 

66 

67 

def register_helper(URIClass: Union[Type[ResourcePath], Type[BinaryIO]]) -> Callable[[_T], _T]:
    """Create a class decorator that registers a load helper.

    The argument is the "handle type", i.e. the type of
    `~lsst.resources.ResourcePath` or open file handle from which data will
    be read. The decorated class is stored in ``HELPER_REGISTRY`` under that
    type so that it can be looked up later when an object of that type is
    used to load data.

    Registration is done through a decorator so that, in principle, a helper
    defined in another module can register itself as well.

    Parameters
    ----------
    URIClass : Type of `~lsst.resources.ResourcePath` or `~IO` of bytes
        Type for which the decorated class should be used as the helper.

    Returns
    -------
    decorator : `Callable`
        Decorator that records the class it is applied to and returns that
        class unchanged.
    """

    def _register(helperClass: _T) -> _T:
        # Record the association, then hand the class back untouched so the
        # decorated name still refers to the original class.
        HELPER_REGISTRY[URIClass] = helperClass
        return helperClass

    return _register

91 

92 

class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter, forces a user to
        interact manually and take this into account before they can be used in
        production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the graph save version is below the
        minimum specified version.
    RuntimeError
        Raised if the save version of the file is newer than the newest
        version this code knows how to read.
    """

    def __init__(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except BaseException:
            # Reading the preamble may already have acquired resources (a
            # subclass may have opened a file in ``_readBytes``). Nobody else
            # holds a reference to this half-built instance, so release them
            # here before propagating the error.
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int) -> bytes:
        """Validate the preamble, select a deserializer and read the header.

        Sets ``self.uriObject`` and ``self.deserializer`` as side effects.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object from which the raw bytes are read.
        minimumVersion : `int`
            Lowest save version that is accepted.

        Returns
        -------
        headerBytes : `bytes`
            Raw bytes of the header, to be interpreted by the deserializer.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match or the save version is
            lower than ``minimumVersion``.
        RuntimeError
            Raised if the save version is greater than ``SAVE_VERSION``.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system, so we just tell MyPy to ignore it.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)  # type: ignore

        # get the header info
        # NOTE(review): ``headerSize`` is passed as the *stop* offset of the
        # read, so it is presumably an absolute position in the file rather
        # than a length -- confirm against DeserializerBase.
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int = 3) -> Optional[str]:
        """Return the unpacked header of a save file without loading a graph.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object from which the raw bytes are read.
        minimumVersion : `int`, optional
            Lowest save version that is accepted, defaults to 3.

        Returns
        -------
        header : `str` or `None`
            Whatever the deserializer's ``unpackHeader`` produces for the
            header bytes.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match or the save version is
            lower than ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than this code can read.
        """
        # Bypass __init__ so that the header info is not deserialized; only
        # the raw header bytes are needed here.
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # Always release whatever the helper acquired, including when the
            # file turns out not to be a valid save.
            instance.close()

    def load(
        self,
        universe: Optional[DimensionUniverse] = None,
        nodes: Optional[Iterable[Union[UUID, str]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe: `~lsst.daf.butler.DimensionUniverse` or None
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
            The universe saved with the graph is used, but if one is passed
            it will be used to validate the compatibility with the loaded
            graph universe.
        nodes : `Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default)
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        RuntimeError
            Raise if Supplied DimensionUniverse is not compatible with the
            DimensionUniverse saved in the graph
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: Set[UUID]
        if nodes is None:
            nodeSet = set(self.headerInfo.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes: Callable[[int, int], bytes] = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        if universe is None:
            universe = self.headerInfo.universe
        return self.deserializer.constructGraph(nodeSet, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The slice ``[start:stop]`` of the buffered contents.
        """
        # The whole content is read once and cached on the instance.
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self) -> None:
        """Cleans up an instance if needed. Base class does nothing"""
        pass

290 

291 

@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph from an S3 URI.

    Each read issues a ``get_object`` request with a ``Range`` header, so
    only the requested bytes are transferred.
    """

    uriObject: S3ResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Fetch the byte range ``[start, stop)`` of the S3 object.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The requested bytes.

        Raises
        ------
        FileNotFoundError
            Raised if the client reports that the key or the bucket does not
            exist.
        """
        if stop <= start:
            # An empty range cannot be expressed with the inclusive Range
            # header; answer it locally, as slicing in the base class would.
            return b""
        client = self.uriObject.client
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop-1}"
        try:
            response = client.get_object(
                Bucket=self.uriObject.netloc,
                Key=self.uriObject.relativeToPathRoot,
                Range=byteRange,
            )
        except (client.exceptions.NoSuchKey, client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        body = response["Body"]
        try:
            return body.read()
        finally:
            # Close the response stream even when the read fails part way.
            body.close()

314 

315 

@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph from a file URI.

    The file at ``uriObject.ospath`` is opened in binary mode on the first
    read and the same handle is reused for every later read until `close`
    is called.
    """

    uriObject: FileResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` by seeking in the file."""
        # Open lazily so that nothing is held before the first read.
        if not hasattr(self, "fileHandle"):
            self.fileHandle = open(self.uriObject.ospath, "rb")
        handle = self.fileHandle
        length = stop - start
        handle.seek(start)
        return handle.read(length)

    def close(self) -> None:
        """Close the underlying file if a read ever opened it."""
        if not hasattr(self, "fileHandle"):
            return
        self.fileHandle.close()

330 

331 

@register_helper(BinaryIO)
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper that reads from an already open binary file handle.

    Unlike the other helpers this one is initialized with an open file
    handle instead of a ResourcePath. The handle is nevertheless stored as
    ``uriObject`` because the interface needed for reading is the same;
    without Protocols this cannot be expressed nicely with typing.

    Partial loading is supported by seeking within the handle.
    """

    def __init__(self, uriObject: BinaryIO, minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: BinaryIO
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # Unlike the default __init__, rewind the io object afterwards so
        # that, should the entire file be read later, the handle is not left
        # in a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` read from the open handle."""
        handle = self.uriObject
        handle.seek(start)
        return handle.read(stop - start)

357 

358 

@dataclass
class LoadHelper(ContextManager[DefaultLoadHelper]):
    """Context manager that selects the appropriate loader for a given
    handle and takes care of closing it when the context exits.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: Union[ResourcePath, BinaryIO]
    """ResourcePath object from which the `QuantumGraph` is to be loaded
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __enter__(self) -> DefaultLoadHelper:
        # Instantiate the helper class registered for this kind of handle and
        # keep a reference so that __exit__ can close it.
        helperClass = self._determineLoader()
        self._loaded = helperClass(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # The exception details are not inspected; the helper is closed
        # regardless and any exception continues to propagate.
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Look up the helper class registered for the type of ``uri``."""
        # Only one handler is registered for anything that is an instance of
        # IOBase, so every such object is looked up under the BinaryIO key;
        # everything else is indexed by its own type.
        #
        # Typing for file-like types is a mess; BinaryIO isn't actually
        # a base class of what open(..., 'rb') returns, and MyPy claims
        # that IOBase and BinaryIO actually have incompatible method
        # signatures. IOBase *is* a base class of what open(..., 'rb')
        # returns, so it's what we have to use at runtime.
        isFileLike = isinstance(self.uri, io.IOBase)  # type: ignore
        key: Union[Type[ResourcePath], Type[BinaryIO]] = BinaryIO if isFileLike else type(self.uri)
        return HELPER_REGISTRY[key]

    def readHeader(self) -> Optional[str]:
        """Return the result of ``dumpHeader`` from the helper class selected
        for ``uri``, without entering the context.
        """
        return self._determineLoader().dumpHeader(self.uri, self.minimumVersion)