Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 29%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

160 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("LoadHelper", ) 

24 

from lsst.daf.butler import ButlerURI, Quantum
from lsst.daf.butler.core._butlerUri.s3 import ButlerS3URI
from lsst.daf.butler.core._butlerUri.file import ButlerFileURI

from ..pipeline import TaskDef
from .quantumNode import NodeId

from dataclasses import dataclass
import functools
import io
import json
import lzma
import pickle
import struct

from collections import defaultdict, UserDict
from typing import (IO, Optional, Iterable, DefaultDict, Set, Dict, TYPE_CHECKING, Tuple, Type, Union)

if TYPE_CHECKING:
    from . import QuantumGraph

45 

46 

class RegistryDict(UserDict):
    """Mapping from handle types to load helper classes.

    Looking up a handle type that was never registered yields
    `DefaultLoadHelper` instead of raising `KeyError`. The lookup does not
    insert the missing key into the mapping.
    """

    def __missing__(self, key):
        # Invoked by UserDict.__getitem__ for keys absent from ``self.data``
        fallbackHelper = DefaultLoadHelper
        return fallbackHelper


# Registry of every load helper class, keyed by the handle type it serves;
# entries are added by the ``register_helper`` decorator.
HELPER_REGISTRY = RegistryDict()

55 

56 

def register_helper(URICLass: Union[Type[ButlerURI], Type[IO[bytes]]]):
    """Used to register classes as Load helpers

    When decorating a class the parameter is the class of "handle type", i.e.
    a ButlerURI type or open file handle that will be used to do the loading.
    This is then associated with the decorated class such that when the
    parameter type is used to load data, the appropriate helper to work with
    that data type can be returned.

    A decorator is used so that in theory someone could define another handler
    in a different module and register it for use.

    Parameters
    ----------
    URICLass : Type of `~lsst.daf.butler.ButlerURI` or `typing.IO` of bytes
        Type for which the decorated class should be mapped to. A later
        registration for the same type replaces an earlier one.

    Returns
    -------
    wrapper : `Callable`
        Class decorator that stores the decorated class in
        ``HELPER_REGISTRY`` under ``URICLass`` and returns the class
        unchanged.
    """
    # NOTE: the parameter name ``URICLass`` (sic) is kept as-is so existing
    # keyword callers keep working.
    def wrapper(class_):
        HELPER_REGISTRY[URICLass] = class_
        return class_
    return wrapper

78 

79 

class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.daf.butler.ButlerURI` or `typing.IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if its save version is not one this
        code knows how to decode.
    RuntimeError
        Raised if the save version is newer than the newest version this code
        can read.

    Notes
    -----
    Save files are decoded with `pickle`, which can execute arbitrary code
    while loading; only load graphs that come from a trusted source.
    """

    def __init__(self, uriObject: Union[ButlerURI, IO[bytes]]):
        self.uriObject = uriObject

        try:
            # infoSize is a tuple of length 2 for version 1 saves (the lengths
            # of two independent pickles), or a tuple of length 1 for version
            # 2 saves (the total length of the header information, minus the
            # magic bytes and version bytes).
            preambleSize, infoSize = self._readSizes()

            # Record the total header size
            if self.save_version == 1:
                self.headerSize = preambleSize + infoSize[0] + infoSize[1]
            elif self.save_version == 2:
                self.headerSize = preambleSize + infoSize[0]
            else:
                raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
                                 "please try a newer version of the code.")

            self._readByteMappings(preambleSize, self.headerSize, infoSize)
        except BaseException:
            # Reading the header can make a subclass acquire resources (for
            # example FileLoadHelper opens a file handle). A failed
            # constructor never hands an instance back to the caller, so no
            # one else could call close(); release them here and re-raise.
            self.close()
            raise

    def _readSizes(self) -> Tuple[int, Tuple[int, ...]]:
        """Read and validate the fixed-size preamble of the save.

        Also stores the decoded save version as ``self.save_version`` so later
        reads can account for format changes between versions.

        Returns
        -------
        preambleSize : `int`
            Number of bytes taken by the magic bytes, the version field and
            the size-information struct; the header maps start here.
        infoSize : `tuple` of `int`
            Values unpacked from the size-information struct.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not identify a `QuantumGraph` save.
        RuntimeError
            Raised if the save version is newer than this code supports.
        """
        # need to import here to avoid cyclic imports
        from .graph import STRUCT_FMT_BASE, MAGIC_BYTES, STRUCT_FMT_STRING, SAVE_VERSION
        # The preamble starts with the magic identifier bytes followed by the
        # version number encoded with STRUCT_FMT_BASE.
        magicSize = len(MAGIC_BYTES)

        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize

        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError("This file does not appear to be a quantum graph save got magic bytes "
                             f"{magic}, expected {MAGIC_BYTES}")

        # Turn the encoded bytes back into a python int object
        save_version, = struct.unpack(STRUCT_FMT_BASE, versionBytes)

        if save_version > SAVE_VERSION:
            raise RuntimeError(f"The version of this save file is {save_version}, but this version of "
                               f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}")

        # The layout of the size information depends on the save version
        fmtString = STRUCT_FMT_STRING[save_version]
        infoSize = struct.calcsize(fmtString)
        infoBytes = self._readBytes(preambleSize, preambleSize+infoSize)
        infoUnpack = struct.unpack(fmtString, infoBytes)

        preambleSize += infoSize

        # Store the save version, so future read codes can make use of any
        # format changes to the save protocol
        self.save_version = save_version

        return preambleSize, infoUnpack

    def _readByteMappings(self, preambleSize: int, headerSize: int, infoSize: Tuple[int, ...]) -> None:
        """Decode the header maps that locate every node and taskDef.

        Sets ``taskDefMap``, ``map``, ``_buildId`` and ``metadata``.

        Parameters
        ----------
        preambleSize : `int`
            Offset at which the header maps start.
        headerSize : `int`
            Offset at which the header maps end. It is taken explicitly so
            subclasses can modify it before this method is called.
        infoSize : `tuple` of `int`
            Size information unpacked by `_readSizes`.

        Raises
        ------
        ValueError
            Raised if the save version is not one this code can decode.
        """
        # read the bytes of taskDef bytes and nodes skipping the size bytes
        headerMaps = self._readBytes(preambleSize, headerSize)

        if self.save_version == 1:
            taskDefSize, _ = infoSize

            # read the map of taskDef bytes back in skipping the size bytes
            self.taskDefMap = pickle.loads(headerMaps[:taskDefSize])

            # read back in the graph id
            self._buildId = self.taskDefMap['__GraphBuildID']

            # read the map of the node objects back in skipping bytes
            # corresponding to the taskDef byte map
            self.map = pickle.loads(headerMaps[taskDefSize:])

            # There is no metadata for old versions
            self.metadata = None
        elif self.save_version == 2:
            # Version 2 headers are LZMA-compressed JSON
            uncompressedHeaderMap = lzma.decompress(headerMaps)
            header = json.loads(uncompressedHeaderMap)
            self.taskDefMap = header['TaskDefs']
            self._buildId = header['GraphBuildID']
            self.map = dict(header['Nodes'])
            self.metadata = header['Metadata']
        else:
            raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
                             "please try a newer version of the code.")

    def _byteRange(self, entry) -> Tuple[int, int]:
        """Return the absolute ``(start, stop)`` byte range of a map entry.

        Offsets stored in ``map`` and ``taskDefMap`` are relative to the end
        of the header, so ``headerSize`` is added to both ends. Version 1
        entries are bare ``(start, stop)`` pairs; later versions keep the pair
        under the ``'bytes'`` key.
        """
        if self.save_version == 1:
            start, stop = entry
        else:
            start, stop = entry['bytes']
        return start + self.headerSize, stop + self.headerSize

    def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = None) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        nodes : `Iterable` of `int` or `None`
            The nodes to load from the graph, loads all if value is None
            (the default)
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        if graphID is not None and self._buildId != graphID:
            raise ValueError('graphID does not match that of the graph being loaded')
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = list(self.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read, so use the reader specialized
            # for this class; a set ensures nodes are only loaded once
            nodes = set(nodes)
            # verify that all nodes requested are in the graph
            remainder = nodes - self.map.keys()
            if remainder:
                raise ValueError(f"Nodes {remainder} were requested, but could not be found in the input "
                                 "graph")
            _readBytes = self._readBytes
        # create a container for loaded data
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # read the node's bytes (reader may be overloaded by subclasses);
            # they are compressed, so decompress before reconstructing
            start, stop = self._byteRange(self.map[node])
            qNode = pickle.loads(lzma.decompress(_readBytes(start, stop)))

            # The persisted node only records which taskDef it belongs to.
            # Load each taskDef the first time it is referenced.
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                start, stop = self._byteRange(self.taskDefMap[nodeTask])
                loadedTaskDef[nodeTask] = pickle.loads(lzma.decompress(_readBytes(start, stop)))
            # Explicitly overload the "frozen-ness" of nodes to attach the
            # taskDef back into the un-persisted node
            object.__setattr__(qNode, 'taskDef', loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # need to import here to avoid cyclic imports; deferred until the
        # graph is actually built
        from . import QuantumGraph

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId,
                            metadata=self.metadata)
        return qGraph

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ButlerURI object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ButlerURI object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.
        """
        if not hasattr(self, 'buffer'):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self):
        """Cleans up an instance if needed. Base class does nothing
        """
        pass

323 

324 

@register_helper(ButlerS3URI)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper that reads byte ranges of a graph stored in S3.

    Each read issues a ranged ``get_object`` request, so loading a subset of
    nodes only transfers the bytes those nodes need.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        # An empty range cannot be expressed in an HTTP Range header
        # ("bytes=N-(N-1)" is invalid and servers ignore it, returning the
        # whole object); match the base class slicing semantics instead.
        if stop <= start:
            return b""
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop-1}"
        client = self.uriObject.client
        try:
            response = client.get_object(Bucket=self.uriObject.netloc,
                                         Key=self.uriObject.relativeToPathRoot,
                                         Range=byteRange)
        except (client.exceptions.NoSuchKey,
                client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        body = response["Body"]
        try:
            return body.read()
        finally:
            # Release the streaming body even if reading it fails
            body.close()

343 

344 

@register_helper(ButlerFileURI)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper that reads byte ranges from a graph on a local filesystem.

    A single binary handle to ``uriObject.ospath`` is opened on the first
    read and reused for every later read until `close` is called.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        try:
            handle = self.fileHandle
        except AttributeError:
            # First read: open the file lazily and cache the handle
            handle = self.fileHandle = open(self.uriObject.ospath, 'rb')
        handle.seek(start)
        return handle.read(stop - start)

    def close(self):
        """Close the cached file handle, if one was ever opened."""
        try:
            handle = self.fileHandle
        except AttributeError:
            return
        handle.close()

357 

358 

@register_helper(io.IOBase)  # type: ignore
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper for an already-open binary file handle.

    This handler is special in that it does not get initialized with a
    ButlerURI, but an open file handle. Most everything stays the same; the
    handle is even stored as ``uriObject`` because the interface needed for
    reading is the same. Without Protocols this cannot be expressed nicely
    with typing.

    This helper does support partial loading: byte ranges are read by
    seeking the handle directly.
    """

    def __init__(self, uriObject: IO[bytes]):
        # Explicitly annotate type and not infer from super
        self.uriObject: IO[bytes]
        super().__init__(uriObject)
        # This differs from the default __init__ to force the io object
        # back to the beginning: a full load uses the base class reader,
        # which calls ``read()`` from the handle's current position, so the
        # handle must not be left in a partially read state.
        # NOTE(review): a partial load moves the position again, so a later
        # full load on the same helper would start mid-file.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        self.uriObject.seek(start)
        return self.uriObject.read(stop - start)

384 

385 

@dataclass
class LoadHelper:
    """This is a helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    Entering the context builds the load helper registered for the type of
    ``uri`` (any `io.IOBase` instance uses the helper registered for
    `io.IOBase`) and returns it; exiting closes that helper.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to `~lsst.daf.butler.ButlerURI`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """
    uri: Union[ButlerURI, IO[bytes]]
    """ButlerURI object or open binary file handle from which the
    `QuantumGraph` is to be loaded
    """

    def __enter__(self) -> DefaultLoadHelper:
        # Only one handler is registered for anything that is an instance of
        # IOBase, so if any type is a subtype of that, set the key explicitly
        # so the correct loader is found, otherwise index by the type
        if isinstance(self.uri, io.IOBase):
            key = io.IOBase
        else:
            key = type(self.uri)
        self._loaded = HELPER_REGISTRY[key](self.uri)
        return self._loaded

    def __exit__(self, type, value, traceback):
        # Returning None means exceptions raised in the body propagate
        self._loaded.close()