Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 40%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

128 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Only the ``LoadHelper`` context manager is exported; the individual helper
# classes defined below are selected internally through the helper registry.
__all__ = ("LoadHelper",)

24 

25import functools 

26import io 

27import struct 

28from dataclasses import dataclass 

29from types import TracebackType 

30from typing import ( 

31 TYPE_CHECKING, 

32 Any, 

33 BinaryIO, 

34 Callable, 

35 ContextManager, 

36 Iterable, 

37 Optional, 

38 Type, 

39 TypeVar, 

40 Union, 

41) 

42from uuid import UUID 

43 

44from lsst.daf.butler import DimensionUniverse 

45from lsst.resources import ResourcePath 

46from lsst.resources.file import FileResourcePath 

47from lsst.resources.s3 import S3ResourcePath 

48 

# These names are only needed for annotations; importing them at runtime
# would create an import cycle with the graph/deserializer modules.
if TYPE_CHECKING:
    from ._versionDeserializers import DeserializerBase
    from .graph import QuantumGraph


# Type variable used so that ``register_helper`` returns the decorated class
# with its original type.
_T = TypeVar("_T")

55 

56 

class RegistryDict(dict):
    """Registry mapping handle types to load helper classes.

    Looking up a key that has never been registered yields
    `DefaultLoadHelper` instead of raising `KeyError`. The missing key is
    not inserted, so unsuccessful lookups leave the registry unchanged.
    """

    def __missing__(self, key: Any) -> Type[DefaultLoadHelper]:
        # Unregistered handle type: fall back to the generic helper, which
        # reads the whole save into memory before slicing it.
        fallback = DefaultLoadHelper
        return fallback


# Registry of all load helper classes, keyed by the handle type each one
# knows how to read; entries are added by the ``register_helper`` decorator.
HELPER_REGISTRY = RegistryDict()

65 

66 

def register_helper(URIClass: Union[Type[ResourcePath], Type[BinaryIO]]) -> Callable[[_T], _T]:
    """Build a class decorator that registers a load helper.

    The decorated class is stored in ``HELPER_REGISTRY`` under ``URIClass``,
    the "handle type" it knows how to read: either a
    `~lsst.resources.ResourcePath` type or an open binary file handle type.
    Later lookups with that handle type return the decorated class, which
    is itself returned unchanged by the decorator.

    Because registration happens through a decorator, a helper defined in
    another module can be made available simply by decorating it there.

    Parameters
    ----------
    URIClass : Type of `~lsst.resources.ResourcePath` or `~IO` of bytes
        Handle type the decorated class should be associated with.

    Returns
    -------
    decorator : `Callable`
        Decorator that records its argument in the registry and returns it.
    """

    def _register(helperClass: _T) -> _T:
        # A later registration for the same handle type replaces an earlier one.
        HELPER_REGISTRY[URIClass] = helperClass
        return helperClass

    return _register

90 

91 

class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files.

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility
    of subclasses to implement the method responsible for loading individual
    byte ranges from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter forces a user to
        interact manually and take this into account before they can be used
        in production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the graph save version is below the
        minimum specified version.
    RuntimeError
        Raised if the graph save version is newer than this code knows how to
        read.
    """

    def __init__(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except BaseException:
            # Construction failed, so no caller will ever hold this instance
            # to close it. Release anything ``_readBytes`` may already have
            # opened (e.g. a file handle) before propagating the error.
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int) -> bytes:
        """Validate the save preamble, pick a deserializer, and return the
        raw header bytes.

        Sets ``self.uriObject`` and ``self.deserializer`` as side effects.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system, so we just tell MyPy to ignore it.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)  # type: ignore

        # get the header info; headerSize is passed as the (exclusive) stop
        # offset of the read, not as a length
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int = 3) -> Optional[str]:
        """Return the unpacked header of a save file without loading the graph.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object used to retrieve the raw bytes of the save.
        minimumVersion : `int`, optional
            Minimum save version accepted, as for the class constructor.

        Returns
        -------
        header : `str` or `None`
            Whatever the version-specific deserializer's ``unpackHeader``
            returns for the header bytes.
        """
        # Bypass __init__: only the raw header is wanted, not the parsed
        # header info that __init__ builds.
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # Always release resources opened while reading, even on error.
            instance.close()

    def load(
        self,
        universe: DimensionUniverse,
        nodes: Optional[Iterable[UUID]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes : `Iterable` of `UUID` or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). String IDs are also accepted and converted to
            `UUID`.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = set(self.headerInfo.map.keys())
            # All nodes are to be read, so force the reader from the base
            # class, which reads all the bytes in one go, even when a
            # subclass supports ranged reads.
            _readBytes: Callable[[int, int], bytes] = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # Only some bytes are being read, using the reader specialized
            # for this class. Build a set so nodes are only loaded once.
            nodes = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodes - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        return self.deserializer.constructGraph(nodes, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object on the first call, and then for
        each call return the requested byte range with Python slice semantics
        (``stop`` is exclusive; an empty or inverted range yields ``b""``).
        This is the most flexible implementation, as no special read is
        required. This will not give a speed up with any sub graph reads
        though.
        """
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self) -> None:
        """Cleans up an instance if needed. Base class does nothing"""
        pass

134 def __setup_impl(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int) -> bytes: 

135 self.uriObject = uriObject 

136 # need to import here to avoid cyclic imports 

137 from ._versionDeserializers import DESERIALIZER_MAP 

138 from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE 

139 

140 # Read the first few bytes which correspond to the magic identifier 

141 # bytes, and save version 

142 magicSize = len(MAGIC_BYTES) 

143 # read in just the fmt base to determine the save version 

144 fmtSize = struct.calcsize(STRUCT_FMT_BASE) 

145 preambleSize = magicSize + fmtSize 

146 headerBytes = self._readBytes(0, preambleSize) 

147 magic = headerBytes[:magicSize] 

148 versionBytes = headerBytes[magicSize:] 

149 

150 if magic != MAGIC_BYTES: 

151 raise ValueError( 

152 "This file does not appear to be a quantum graph save got magic bytes " 

153 f"{magic!r}, expected {MAGIC_BYTES!r}" 

154 ) 

155 

156 # unpack the save version bytes and verify it is a version that this 

157 # code can understand 

158 (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes) 

159 # loads can sometimes trigger upgrades in format to a latest version, 

160 # in which case accessory code might not match the upgraded graph. 

161 # I.E. switching from old node number to UUID. This clause necessitates 

162 # that users specifically interact with older graph versions and verify 

163 # everything happens appropriately. 

164 if save_version < minimumVersion: 

165 raise ValueError( 

166 f"The loaded QuantumGraph is version {save_version}, and the minimum " 

167 f"version specified is {minimumVersion}. Please re-run this method " 

168 "with a lower minimum version, then re-save the graph to automatically upgrade" 

169 "to the newest version. Older versions may not work correctly with newer code" 

170 ) 

171 

172 if save_version > SAVE_VERSION: 

173 raise RuntimeError( 

174 f"The version of this save file is {save_version}, but this version of" 

175 f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}" 

176 ) 

177 

178 # select the appropriate deserializer for this save version 

179 deserializerClass = DESERIALIZER_MAP[save_version] 

180 

181 # read in the bytes corresponding to the mappings and initialize the 

182 # deserializer. This will be the bytes that describe the following 

183 # byte boundaries of the header info 

184 sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize) 

185 # DeserializerBase subclasses are required to have the same constructor 

186 # signature as the base class itself, but there is no way to express 

187 # this in the type system, so we just tell MyPy to ignore it. 

188 self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes) # type: ignore 

189 

190 # get the header info 

191 headerBytes = self._readBytes( 

192 preambleSize + deserializerClass.structSize, self.deserializer.headerSize 

193 ) 

194 return headerBytes 

195 

196 @classmethod 

197 def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int = 3) -> Optional[str]: 

198 instance = cls.__new__(cls) 

199 headerBytes = instance.__setup_impl(uriObject, minimumVersion) 

200 header = instance.deserializer.unpackHeader(headerBytes) 

201 instance.close() 

202 return header 

203 

204 def load( 

205 self, 

206 universe: DimensionUniverse, 

207 nodes: Optional[Iterable[UUID]] = None, 

208 graphID: Optional[str] = None, 

209 ) -> QuantumGraph: 

210 """Loads in the specified nodes from the graph 

211 

212 Load in the `QuantumGraph` containing only the nodes specified in the 

213 ``nodes`` parameter from the graph specified at object creation. If 

214 ``nodes`` is None (the default) the whole graph is loaded. 

215 

216 Parameters 

217 ---------- 

218 universe: `~lsst.daf.butler.DimensionUniverse` 

219 DimensionUniverse instance, not used by the method itself but 

220 needed to ensure that registry data structures are initialized. 

221 nodes : `Iterable` of `UUID` or `None` 

222 The nodes to load from the graph, loads all if value is None 

223 (the default) 

224 graphID : `str` or `None` 

225 If specified this ID is verified against the loaded graph prior to 

226 loading any Nodes. This defaults to None in which case no 

227 validation is done. 

228 

229 Returns 

230 ------- 

231 graph : `QuantumGraph` 

232 The loaded `QuantumGraph` object 

233 

234 Raises 

235 ------ 

236 ValueError 

237 Raised if one or more of the nodes requested is not in the 

238 `QuantumGraph` or if graphID parameter does not match the graph 

239 being loaded. 

240 """ 

241 # verify this is the expected graph 

242 if graphID is not None and self.headerInfo._buildId != graphID: 

243 raise ValueError("graphID does not match that of the graph being loaded") 

244 # Read in specified nodes, or all the nodes 

245 if nodes is None: 

246 nodes = set(self.headerInfo.map.keys()) 

247 # if all nodes are to be read, force the reader from the base class 

248 # that will read all they bytes in one go 

249 _readBytes: Callable[[int, int], bytes] = functools.partial(DefaultLoadHelper._readBytes, self) 

250 else: 

251 # only some bytes are being read using the reader specialized for 

252 # this class 

253 # create a set to ensure nodes are only loaded once 

254 nodes = {UUID(n) if isinstance(n, str) else n for n in nodes} 

255 # verify that all nodes requested are in the graph 

256 remainder = nodes - self.headerInfo.map.keys() 

257 if remainder: 

258 raise ValueError( 

259 f"Nodes {remainder} were requested, but could not be found in the input graph" 

260 ) 

261 _readBytes = self._readBytes 

262 return self.deserializer.constructGraph(nodes, _readBytes, universe) 

263 

264 def _readBytes(self, start: int, stop: int) -> bytes: 

265 """Loads the specified byte range from the ResourcePath object 

266 

267 In the base class, this actually will read all the bytes into a buffer 

268 from the specified ResourcePath object. Then for each method call will 

269 return the requested byte range. This is the most flexible 

270 implementation, as no special read is required. This will not give a 

271 speed up with any sub graph reads though. 

272 """ 

273 if not hasattr(self, "buffer"): 

274 self.buffer = self.uriObject.read() 

275 return self.buffer[start:stop] 

276 

277 def close(self) -> None: 

278 """Cleans up an instance if needed. Base class does nothing""" 

279 pass 

280 

281 

282@register_helper(S3ResourcePath) 

283class S3LoadHelper(DefaultLoadHelper): 

284 # This subclass implements partial loading of a graph using a s3 uri 

285 def _readBytes(self, start: int, stop: int) -> bytes: 

286 args = {} 

287 # minus 1 in the stop range, because this header is inclusive rather 

288 # than standard python where the end point is generally exclusive 

289 args["Range"] = f"bytes={start}-{stop-1}" 

290 try: 

291 response = self.uriObject.client.get_object( 

292 Bucket=self.uriObject.netloc, Key=self.uriObject.relativeToPathRoot, **args 

293 ) 

294 except ( 

295 self.uriObject.client.exceptions.NoSuchKey, 

296 self.uriObject.client.exceptions.NoSuchBucket, 

297 ) as err: 

298 raise FileNotFoundError(f"No such resource: {self.uriObject}") from err 

299 body = response["Body"].read() 

300 response["Body"].close() 

301 return body 

302 

303 uriObject: S3ResourcePath 

304 

305 

306@register_helper(FileResourcePath) 

307class FileLoadHelper(DefaultLoadHelper): 

308 # This subclass implements partial loading of a graph using a file uri 

309 def _readBytes(self, start: int, stop: int) -> bytes: 

310 if not hasattr(self, "fileHandle"): 

311 self.fileHandle = open(self.uriObject.ospath, "rb") 

312 self.fileHandle.seek(start) 

313 return self.fileHandle.read(stop - start) 

314 

315 def close(self) -> None: 

316 if hasattr(self, "fileHandle"): 

317 self.fileHandle.close() 

318 

319 uriObject: FileResourcePath 

320 

321 

322@register_helper(BinaryIO) 

323class OpenFileHandleHelper(DefaultLoadHelper): 

324 # This handler is special in that it does not get initialized with a 

325 # ResourcePath, but an open file handle. 

326 

327 # Most everything stays the same, the variable is even stored as uriObject, 

328 # because the interface needed for reading is the same. Unfortunately 

329 # because we do not have Protocols yet, this can not be nicely expressed 

330 # with typing. 

331 

332 # This helper does support partial loading 

333 

334 def __init__(self, uriObject: BinaryIO, minimumVersion: int): 

335 # Explicitly annotate type and not infer from super 

336 self.uriObject: BinaryIO 

337 super().__init__(uriObject, minimumVersion=minimumVersion) 

338 # This differs from the default __init__ to force the io object 

339 # back to the beginning so that in the case the entire file is to 

340 # read in the file is not already in a partially read state. 

341 self.uriObject.seek(0) 

342 

343 def _readBytes(self, start: int, stop: int) -> bytes: 

344 self.uriObject.seek(start) 

345 result = self.uriObject.read(stop - start) 

346 return result 

347 

348 

349@dataclass 

350class LoadHelper(ContextManager[DefaultLoadHelper]): 

351 """This is a helper class to assist with selecting the appropriate loader 

352 and managing any contexts that may be needed. 

353 

354 Note 

355 ---- 

356 This class may go away or be modified in the future if some of the 

357 features of this module can be propagated to 

358 `~lsst.resources.ResourcePath`. 

359 

360 This helper will raise a `ValueError` if the specified file does not appear 

361 to be a valid `QuantumGraph` save file. 

362 """ 

363 

364 uri: Union[ResourcePath, BinaryIO] 

365 """ResourcePath object from which the `QuantumGraph` is to be loaded 

366 """ 

367 minimumVersion: int 

368 """ 

369 Minimum version of a save file to load. Set to -1 to load all 

370 versions. Older versions may need to be loaded, and re-saved 

371 to upgrade them to the latest format before they can be used in 

372 production. 

373 """ 

374 

375 def __enter__(self) -> DefaultLoadHelper: 

376 # Only one handler is registered for anything that is an instance of 

377 # IOBase, so if any type is a subtype of that, set the key explicitly 

378 # so the correct loader is found, otherwise index by the type 

379 self._loaded = self._determineLoader()(self.uri, self.minimumVersion) 

380 return self._loaded 

381 

382 def __exit__( 

383 self, 

384 type: Optional[Type[BaseException]], 

385 value: Optional[BaseException], 

386 traceback: Optional[TracebackType], 

387 ) -> None: 

388 self._loaded.close() 

389 

390 def _determineLoader(self) -> Type[DefaultLoadHelper]: 

391 key: Union[Type[ResourcePath], Type[BinaryIO]] 

392 # Typing for file-like types is a mess; BinaryIO isn't actually 

393 # a base class of what open(..., 'rb') returns, and MyPy claims 

394 # that IOBase and BinaryIO actually have incompatible method 

395 # signatures. IOBase *is* a base class of what open(..., 'rb') 

396 # returns, so it's what we have to use at runtime. 

397 if isinstance(self.uri, io.IOBase): # type: ignore 

398 key = BinaryIO 

399 else: 

400 key = type(self.uri) 

401 return HELPER_REGISTRY[key] 

402 

403 def readHeader(self) -> Optional[str]: 

404 type_ = self._determineLoader() 

405 return type_.dumpHeader(self.uri, self.minimumVersion)