Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 30%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

161 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("LoadHelper", ) 

24 

25from lsst.resources import ResourcePath 

26from lsst.daf.butler import Quantum 

27from lsst.resources.s3 import S3ResourcePath 

28from lsst.resources.file import FileResourcePath 

29 

30from ..pipeline import TaskDef 

31from .quantumNode import NodeId 

32 

33from dataclasses import dataclass 

34import functools 

35import io 

36import json 

37import lzma 

38import pickle 

39import struct 

40 

41from collections import defaultdict, UserDict 

42from typing import (Optional, Iterable, DefaultDict, Set, Dict, TYPE_CHECKING, Tuple, Type, Union) 

43 

44if TYPE_CHECKING: 44 ↛ 45line 44 didn't jump to line 45, because the condition on line 44 was never true

45 from . import QuantumGraph 

46 

47 

class RegistryDict(UserDict):
    """Dictionary of load helpers with a fallback for unknown keys.

    Looking up a key that has not been registered returns
    `DefaultLoadHelper`; the key itself is not added to the mapping.
    """

    def __missing__(self, key):
        # Any handle type without a dedicated helper is served by the
        # generic read-everything implementation.
        return DefaultLoadHelper

52 

53 

# Module-level registry mapping a "handle type" (the class of the object data
# is read from) to the load helper class that knows how to read from it.
HELPER_REGISTRY = RegistryDict()

56 

57 

def register_helper(URICLass: Union[Type[ResourcePath], Type[io.IO[bytes]]]):
    """Build a class decorator that registers a load helper.

    The argument is the "handle type", i.e. the class of the
    ResourcePath or open file handle that data will be read from. The
    decorated class is stored in `HELPER_REGISTRY` under that type, so that
    the matching helper can be looked up when an object of that type is
    used to load a graph.

    Registration is done through a decorator so that a helper defined in
    another module can be registered as well.

    Parameters
    ----------
    URICLass : Type of `~lsst.resources.ResourcePath` or `~io.IO` of bytes
        Type that the decorated class should be registered for.

    Returns
    -------
    decorator : callable
        Function that records the class passed to it in the registry and
        returns that class unchanged.
    """
    def _register(helperClass):
        # Store the association and hand the class back untouched so the
        # decorated name still refers to the original class.
        HELPER_REGISTRY[URICLass] = helperClass
        return helperClass

    return _register

79 

80 

class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `io.IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the save version is not one this
        class knows how to unpack.
    RuntimeError
        Raised if the save file was written with a newer save version than
        this code supports.
    """
    def __init__(self, uriObject: Union[ResourcePath, io.IO[bytes]]):
        self.uriObject = uriObject

        try:
            # infoSize is either a tuple of length 2 (version 1) holding the
            # lengths of 2 independent pickles, or a tuple of length 1
            # holding the total length of the entire header information
            # (minus the magic bytes and version bytes)
            preambleSize, infoSize = self._readSizes()

            # Record the total header size
            if self.save_version == 1:
                self.headerSize = preambleSize + infoSize[0] + infoSize[1]
            elif self.save_version == 2:
                self.headerSize = preambleSize + infoSize[0]
            else:
                raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
                                 "please try a newer version of the code.")

            self._readByteMappings(preambleSize, self.headerSize, infoSize)
        except BaseException:
            # Reading the header may already have acquired resources in a
            # subclass (e.g. an opened file). Nobody will get a reference to
            # this half-built instance, so release them before propagating.
            self.close()
            raise

    def _readSizes(self) -> Tuple[int, Tuple[int, ...]]:
        """Read and validate the fixed-size start of the save file.

        Sets ``self.save_version`` as a side effect.

        Returns
        -------
        preambleSize : `int`
            Number of bytes taken up by the magic bytes, the version field
            and the size fields, i.e. the offset at which the header maps
            start.
        infoUnpack : `tuple` of `int`
            The unpacked size fields; their number depends on the save
            version.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match.
        RuntimeError
            Raised if the save version is newer than ``SAVE_VERSION``.
        """
        # need to import here to avoid cyclic imports
        from .graph import STRUCT_FMT_BASE, MAGIC_BYTES, STRUCT_FMT_STRING, SAVE_VERSION
        # Read the first few bytes which correspond to the magic identifier
        # bytes and the version number
        magicSize = len(MAGIC_BYTES)

        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize

        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError("This file does not appear to be a quantum graph save got magic bytes "
                             f"{magic}, expected {MAGIC_BYTES}")

        # Turn the encoded bytes back into a python int object
        save_version, = struct.unpack(STRUCT_FMT_BASE, versionBytes)

        if save_version > SAVE_VERSION:
            raise RuntimeError(f"The version of this save file is {save_version}, but this version of "
                               f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}")

        # read in the size fields, whose layout depends on the save version
        fmtString = STRUCT_FMT_STRING[save_version]
        infoSize = struct.calcsize(fmtString)
        infoBytes = self._readBytes(preambleSize, preambleSize+infoSize)
        infoUnpack = struct.unpack(fmtString, infoBytes)

        preambleSize += infoSize

        # Store the save version, so future read codes can make use of any
        # format changes to the save protocol
        self.save_version = save_version

        return preambleSize, infoUnpack

    def _readByteMappings(self, preambleSize: int, headerSize: int, infoSize: Tuple[int, ...]) -> None:
        """Read the header maps that locate taskDefs and nodes in the file.

        Populates ``self.taskDefMap``, ``self.map``, ``self._buildId`` and
        ``self.metadata``.

        Parameters
        ----------
        preambleSize : `int`
            Offset of the first byte of the header maps.
        headerSize : `int`
            Offset one past the last byte of the header maps. Taken
            explicitly so subclasses can modify it before this method is
            called.
        infoSize : `tuple` of `int`
            Size fields as returned by ``_readSizes``.

        Raises
        ------
        ValueError
            Raised if ``self.save_version`` is not a known version.
        """
        # read the bytes of taskDef bytes and nodes skipping the size bytes
        headerMaps = self._readBytes(preambleSize, headerSize)

        if self.save_version == 1:
            taskDefSize, _ = infoSize

            # NOTE(review): pickle.loads can execute arbitrary code; graphs
            # must only be loaded from trusted sources.
            # read the map of taskDef bytes back in skipping the size bytes
            self.taskDefMap = pickle.loads(headerMaps[:taskDefSize])

            # read back in the graph id
            self._buildId = self.taskDefMap['__GraphBuildID']

            # read the map of the node objects back in skipping bytes
            # corresponding to the taskDef byte map
            self.map = pickle.loads(headerMaps[taskDefSize:])

            # There is no metadata for old versions
            self.metadata = None
        elif self.save_version == 2:
            # version 2 stores a single lzma compressed json document
            uncompressedHeaderMap = lzma.decompress(headerMaps)
            header = json.loads(uncompressedHeaderMap)
            self.taskDefMap = header['TaskDefs']
            self._buildId = header['GraphBuildID']
            self.map = dict(header['Nodes'])
            self.metadata = header['Metadata']
        else:
            raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
                             "please try a newer version of the code.")

    def _byteRange(self, entry) -> Tuple[int, int]:
        """Convert a byte-map entry into absolute offsets in the save file.

        Parameters
        ----------
        entry
            Value taken from ``self.map`` or ``self.taskDefMap``. For save
            version 1 this is the ``(start, stop)`` pair itself, otherwise
            the pair is stored under the ``'bytes'`` key.

        Returns
        -------
        start, stop : `int`
            The stored offsets shifted by ``self.headerSize``.
        """
        if self.save_version == 1:
            start, stop = entry
        else:
            start, stop = entry['bytes']
        return start + self.headerSize, stop + self.headerSize

    def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = None) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        nodes : `Iterable` of `int` or `None`
            The nodes to load from the graph, loads all if value is None
            (the default)
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        if graphID is not None and self._buildId != graphID:
            raise ValueError('graphID does not match that of the graph being loaded')
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = list(self.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodes = set(nodes)
            # verify that all nodes requested are in the graph
            remainder = nodes - self.map.keys()
            if remainder:
                raise ValueError(f"Nodes {remainder} were requested, but could not be found in the input "
                                 "graph")
            _readBytes = self._readBytes
        # create a container for loaded data
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # Get the bytes to read from the map
            start, stop = self._byteRange(self.map[node])

            # read the specified bytes, will be overloaded by subclasses
            # bytes are compressed, so decompress them
            dump = lzma.decompress(_readBytes(start, stop))

            # reconstruct node
            # NOTE(review): pickle.loads can execute arbitrary code; graphs
            # must only be loaded from trusted sources.
            qNode = pickle.loads(dump)

            # read the saved node, name. If it has been loaded, attach it, if
            # not read in the taskDef first, and then load it
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self._byteRange(self.taskDefMap[nodeTask])

                # load the taskDef, this method call will be overloaded by
                # subclasses.
                # bytes are compressed, so decompress them
                taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop)))
                loadedTaskDef[nodeTask] = taskDef
            # Explicitly overload the "frozen-ness" of nodes to attach the
            # taskDef back into the un-persisted node
            object.__setattr__(qNode, 'taskDef', loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # need to import here to avoid cyclic imports; deferred until it is
        # actually needed so that argument validation happens first
        from . import QuantumGraph

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId,
                            metadata=self.metadata)
        return qGraph

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The slice ``[start:stop]`` of the save file contents.
        """
        # The full contents are fetched on first use and cached on the
        # instance for every later call.
        if not hasattr(self, 'buffer'):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self):
        """Cleans up an instance if needed. Base class does nothing
        """
        pass

324 

325 

@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph using a s3 uri.

    Only the requested byte range is fetched, by passing a ``Range``
    argument to the client's ``get_object`` call.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Fetch the bytes ``[start, stop)`` of the object.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The requested bytes; empty if ``stop <= start``.

        Raises
        ------
        FileNotFoundError
            Raised if the client reports that the key or bucket does not
            exist.
        """
        # An empty or inverted range cannot be written as a valid Range
        # header, so answer it locally the way a slice would.
        if stop <= start:
            return b""
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop-1}"
        client = self.uriObject.client
        try:
            response = client.get_object(Bucket=self.uriObject.netloc,
                                         Key=self.uriObject.relativeToPathRoot,
                                         Range=byteRange)
        except (client.exceptions.NoSuchKey,
                client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        stream = response["Body"]
        try:
            return stream.read()
        finally:
            # release the stream even if the read fails part way
            stream.close()

344 

345 

@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph using a file uri.

    The file at ``uriObject.ospath`` is opened on the first read and the
    handle is kept on the instance until `close` is called.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes ``[start, stop)`` of the file."""
        try:
            handle = self.fileHandle
        except AttributeError:
            # first read: open the file and remember the handle
            handle = self.fileHandle = open(self.uriObject.ospath, 'rb')
        handle.seek(start)
        return handle.read(stop - start)

    def close(self):
        """Close the file handle if one was opened."""
        handle = getattr(self, 'fileHandle', None)
        if handle is not None:
            handle.close()

358 

359 

@register_helper(io.IOBase)  # type: ignore
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper for an already open file handle.

    This handler is special in that it is initialized with an open file
    handle rather than a ResourcePath. The handle is still stored as
    ``uriObject`` because the interface needed for reading is the same;
    without Protocols this can not be nicely expressed with typing.

    This helper does support partial loading.
    """

    def __init__(self, uriObject: io.IO[bytes]):
        # Explicitly annotate type and not infer from super
        self.uriObject: io.IO[bytes]
        super().__init__(uriObject)
        # Unlike the default __init__, rewind the io object afterwards so
        # that a later read of the entire file does not start from a
        # partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes ``[start, stop)`` by seeking on the handle."""
        handle = self.uriObject
        handle.seek(start)
        return handle.read(stop - start)

385 

386 

@dataclass
class LoadHelper:
    """Context manager that picks the appropriate loader for a uri and
    closes it again on exit.

    Entering the context returns the load helper instance registered for
    the type of ``uri``.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """
    uri: ResourcePath
    """ResourcePath object from which the `QuantumGraph` is to be loaded
    """

    def __enter__(self):
        # A single handler is registered for everything derived from IOBase,
        # so any such object is looked up under IOBase itself; every other
        # object is looked up under its exact type.
        key = io.IOBase if isinstance(self.uri, io.IOBase) else type(self.uri)
        helperClass = HELPER_REGISTRY[key]
        self._loaded = helperClass(self.uri)
        return self._loaded

    def __exit__(self, type, value, traceback):
        # Release whatever the selected helper acquired.
        self._loaded.close()