Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("LoadHelper") 

24 

25from lsst.daf.butler import ButlerURI, Quantum 

26from lsst.daf.butler.core._butlerUri.s3 import ButlerS3URI 

27from lsst.daf.butler.core._butlerUri.file import ButlerFileURI 

28 

29from ..pipeline import TaskDef 

30from .quantumNode import NodeId 

31 

32from dataclasses import dataclass 

33import functools 

34import io 

35import lzma 

36import pickle 

37import struct 

38 

39from collections import defaultdict, UserDict 

40from typing import (Optional, Iterable, DefaultDict, Set, Dict, TYPE_CHECKING, Type, Union) 

41 

42if TYPE_CHECKING: 42 ↛ 43line 42 didn't jump to line 43, because the condition on line 42 was never true

43 from . import QuantumGraph 

44 

45 

46# Create a custom dict that will return the desired default if a key is missing 

47class RegistryDict(UserDict): 

48 def __missing__(self, key): 

49 return DefaultLoadHelper 

50 

51 

52# Create a registry to hold all the load Helper classes 

53HELPER_REGISTRY = RegistryDict() 

54 

55 

56def register_helper(URICLass: Union[Type[ButlerURI], Type[io.IO[bytes]]]): 

57 """Used to register classes as Load helpers 

58 

59 When decorating a class the parameter is the class of "handle type", i.e. 

60 a ButlerURI type or open file handle that will be used to do the loading. 

61 This is then associated with the decorated class such that when the 

62 parameter type is used to load data, the appropriate helper to work with 

63 that data type can be returned. 

64 

65 A decorator is used so that in theory someone could define another handler 

66 in a different module and register it for use. 

67 

68 Parameters 

69 ---------- 

70 URIClass : Type of `~lsst.daf.butler.ButlerURI` or `~io.IO` of bytes 

71 type for which the decorated class should be mapped to 

72 """ 

73 def wrapper(class_): 

74 HELPER_REGISTRY[URICLass] = class_ 

75 return class_ 

76 return wrapper 

77 

78 

79class DefaultLoadHelper: 

80 """Default load helper for `QuantumGraph` save files 

81 

82 This class, and its subclasses, are used to unpack a quantum graph save 

83 file. This file is a binary representation of the graph in a format that 

84 allows individual nodes to be loaded without needing to load the entire 

85 file. 

86 

87 This default implementation has the interface to load select nodes 

88 from disk, but actually always loads the entire save file and simply 

89 returns what nodes (or all) are requested. This is intended to serve for 

90 all cases where there is a read method on the input parameter, but it is 

91 unknown how to read select bytes of the stream. It is the responsibility of 

92 sub classes to implement the method responsible for loading individual 

93 bytes from the stream. 

94 

95 Parameters 

96 ---------- 

97 uriObject : `~lsst.daf.butler.ButlerURI` or `io.IO` of bytes 

98 This is the object that will be used to retrieve the raw bytes of the 

99 save. 

100 

101 Raises 

102 ------ 

103 ValueError 

104 Raised if the specified file contains the wrong file signature and is 

105 not a `QuantumGraph` save 

106 """ 

107 def __init__(self, uriObject: Union[ButlerURI, io.IO[bytes]]): 

108 self.uriObject = uriObject 

109 

110 preambleSize, taskDefSize, nodeSize = self._readSizes() 

111 

112 # Recode the total header size 

113 self.headerSize = preambleSize + taskDefSize + nodeSize 

114 

115 self._readByteMappings(preambleSize, self.headerSize, taskDefSize) 

116 

117 def _readSizes(self): 

118 # need to import here to avoid cyclic imports 

119 from .graph import STRUCT_FMT_STRING, MAGIC_BYTES 

120 # Read the first few bytes which correspond to the lengths of the 

121 # magic identifier bytes, 2 byte version 

122 # number and the two 8 bytes numbers that are the sizes of the byte 

123 # maps 

124 magicSize = len(MAGIC_BYTES) 

125 fmt = STRUCT_FMT_STRING 

126 fmtSize = struct.calcsize(fmt) 

127 preambleSize = magicSize + fmtSize 

128 

129 headerBytes = self._readBytes(0, preambleSize) 

130 magic = headerBytes[:magicSize] 

131 sizeBytes = headerBytes[magicSize:] 

132 

133 if magic != MAGIC_BYTES: 

134 raise ValueError("This file does not appear to be a quantum graph save got magic bytes " 

135 f"{magic}, expected {MAGIC_BYTES}") 

136 

137 # Turn they encode bytes back into a python int object 

138 save_version, taskDefSize, nodeSize = struct.unpack('>HQQ', sizeBytes) 

139 

140 # Store the save version, so future read codes can make use of any 

141 # format changes to the save protocol 

142 self.save_version = save_version 

143 

144 return preambleSize, taskDefSize, nodeSize 

145 

146 def _readByteMappings(self, preambleSize, headerSize, taskDefSize): 

147 # Take the header size explicitly so subclasses can modify before 

148 # This task is called 

149 

150 # read the bytes of taskDef bytes and nodes skipping the size bytes 

151 headerMaps = self._readBytes(preambleSize, headerSize) 

152 

153 # read the map of taskDef bytes back in skipping the size bytes 

154 self.taskDefMap = pickle.loads(headerMaps[:taskDefSize]) 

155 

156 # read back in the graph id 

157 self._buildId = self.taskDefMap['__GraphBuildID'] 

158 

159 # read the map of the node objects back in skipping bytes 

160 # corresponding to the taskDef byte map 

161 self.map = pickle.loads(headerMaps[taskDefSize:]) 

162 

163 def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = None) -> QuantumGraph: 

164 """Loads in the specified nodes from the graph 

165 

166 Load in the `QuantumGraph` containing only the nodes specified in the 

167 ``nodes`` parameter from the graph specified at object creation. If 

168 ``nodes`` is None (the default) the whole graph is loaded. 

169 

170 Parameters 

171 ---------- 

172 nodes : `Iterable` of `int` or `None` 

173 The nodes to load from the graph, loads all if value is None 

174 (the default) 

175 graphID : `str` or `None` 

176 If specified this ID is verified against the loaded graph prior to 

177 loading any Nodes. This defaults to None in which case no 

178 validation is done. 

179 

180 Returns 

181 ------- 

182 graph : `QuantumGraph` 

183 The loaded `QuantumGraph` object 

184 

185 Raises 

186 ------ 

187 ValueError 

188 Raised if one or more of the nodes requested is not in the 

189 `QuantumGraph` or if graphID parameter does not match the graph 

190 being loaded. 

191 """ 

192 # need to import here to avoid cyclic imports 

193 from . import QuantumGraph 

194 if graphID is not None and self._buildId != graphID: 

195 raise ValueError('graphID does not match that of the graph being loaded') 

196 # Read in specified nodes, or all the nodes 

197 if nodes is None: 

198 nodes = list(self.map.keys()) 

199 # if all nodes are to be read, force the reader from the base class 

200 # that will read all they bytes in one go 

201 _readBytes = functools.partial(DefaultLoadHelper._readBytes, self) 

202 else: 

203 # only some bytes are being read using the reader specialized for 

204 # this class 

205 # create a set to ensure nodes are only loaded once 

206 nodes = set(nodes) 

207 # verify that all nodes requested are in the graph 

208 remainder = nodes - self.map.keys() 

209 if remainder: 

210 raise ValueError("Nodes {remainder} were requested, but could not be found in the input " 

211 "graph") 

212 _readBytes = self._readBytes 

213 # create a container for loaded data 

214 quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set) 

215 quantumToNodeId: Dict[Quantum, NodeId] = {} 

216 loadedTaskDef = {} 

217 # loop over the nodes specified above 

218 for node in nodes: 

219 # Get the bytes to read from the map 

220 start, stop = self.map[node] 

221 start += self.headerSize 

222 stop += self.headerSize 

223 

224 # read the specified bytes, will be overloaded by subclasses 

225 # bytes are compressed, so decompress them 

226 dump = lzma.decompress(_readBytes(start, stop)) 

227 

228 # reconstruct node 

229 qNode = pickle.loads(dump) 

230 

231 # read the saved node, name. If it has been loaded, attach it, if 

232 # not read in the taskDef first, and then load it 

233 nodeTask = qNode.taskDef 

234 if nodeTask not in loadedTaskDef: 

235 # Get the byte ranges corresponding to this taskDef 

236 start, stop = self.taskDefMap[nodeTask] 

237 start += self.headerSize 

238 stop += self.headerSize 

239 

240 # load the taskDef, this method call will be overloaded by 

241 # subclasses. 

242 # bytes are compressed, so decompress them 

243 taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop))) 

244 loadedTaskDef[nodeTask] = taskDef 

245 # Explicitly overload the "frozen-ness" of nodes to attach the 

246 # taskDef back into the un-persisted node 

247 object.__setattr__(qNode, 'taskDef', loadedTaskDef[nodeTask]) 

248 quanta[qNode.taskDef].add(qNode.quantum) 

249 

250 # record the node for later processing 

251 quantumToNodeId[qNode.quantum] = qNode.nodeId 

252 

253 # construct an empty new QuantumGraph object, and run the associated 

254 # creation method with the un-persisted data 

255 qGraph = object.__new__(QuantumGraph) 

256 qGraph._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId) 

257 return qGraph 

258 

259 def _readBytes(self, start: int, stop: int) -> bytes: 

260 """Loads the specified byte range from the ButlerURI object 

261 

262 In the base class, this actually will read all the bytes into a buffer 

263 from the specified ButlerURI object. Then for each method call will 

264 return the requested byte range. This is the most flexible 

265 implementation, as no special read is required. This will not give a 

266 speed up with any sub graph reads though. 

267 """ 

268 if not hasattr(self, 'buffer'): 

269 self.buffer = self.uriObject.read() 

270 return self.buffer[start:stop] 

271 

272 def close(self): 

273 """Cleans up an instance if needed. Base class does nothing 

274 """ 

275 pass 

276 

277 

278@register_helper(ButlerS3URI) 

279class S3LoadHelper(DefaultLoadHelper): 

280 # This subclass implements partial loading of a graph using a s3 uri 

281 def _readBytes(self, start: int, stop: int) -> bytes: 

282 args = {} 

283 # minus 1 in the stop range, because this header is inclusive rather 

284 # than standard python where the end point is generally exclusive 

285 args["Range"] = f"bytes={start}-{stop-1}" 

286 try: 

287 response = self.uriObject.client.get_object(Bucket=self.uriObject.netloc, 

288 Key=self.uriObject.relativeToPathRoot, 

289 **args) 

290 except (self.uriObject.client.exceptions.NoSuchKey, 

291 self.uriObject.client.exceptions.NoSuchBucket) as err: 

292 raise FileNotFoundError(f"No such resource: {self.uriObject}") from err 

293 body = response["Body"].read() 

294 response["Body"].close() 

295 return body 

296 

297 

298@register_helper(ButlerFileURI) 

299class FileLoadHelper(DefaultLoadHelper): 

300 # This subclass implements partial loading of a graph using a file uri 

301 def _readBytes(self, start: int, stop: int) -> bytes: 

302 if not hasattr(self, 'fileHandle'): 

303 self.fileHandle = open(self.uriObject.ospath, 'rb') 

304 self.fileHandle.seek(start) 

305 return self.fileHandle.read(stop-start) 

306 

307 def close(self): 

308 if hasattr(self, 'fileHandle'): 

309 self.fileHandle.close() 

310 

311 

312@register_helper(io.IOBase) # type: ignore 

313class OpenFileHandleHelper(DefaultLoadHelper): 

314 # This handler is special in that it does not get initialized with a 

315 # ButlerURI, but an open file handle. 

316 

317 # Most everything stays the same, the variable is even stored as uriObject, 

318 # because the interface needed for reading is the same. Unfortunately 

319 # because we do not have Protocols yet, this can not be nicely expressed 

320 # with typing. 

321 

322 # This helper does support partial loading 

323 

324 def __init__(self, uriObject: io.IO[bytes]): 

325 # Explicitly annotate type and not infer from super 

326 self.uriObject: io.IO[bytes] 

327 super().__init__(uriObject) 

328 # This differs from the default __init__ to force the io object 

329 # back to the beginning so that in the case the entire file is to 

330 # read in the file is not already in a partially read state. 

331 self.uriObject.seek(0) 

332 

333 def _readBytes(self, start: int, stop: int) -> bytes: 

334 self.uriObject.seek(start) 

335 result = self.uriObject.read(stop-start) 

336 return result 

337 

338 

339@dataclass 

340class LoadHelper: 

341 """This is a helper class to assist with selecting the appropriate loader 

342 and managing any contexts that may be needed. 

343 

344 Note 

345 ---- 

346 This class may go away or be modified in the future if some of the 

347 features of this module can be propagated to `~lsst.daf.butler.ButlerURI`. 

348 

349 This helper will raise a `ValueError` if the specified file does not appear 

350 to be a valid `QuantumGraph` save file. 

351 """ 

352 uri: ButlerURI 

353 """ButlerURI object from which the `QuantumGraph` is to be loaded 

354 """ 

355 def __enter__(self): 

356 # Only one handler is registered for anything that is an instance of 

357 # IOBase, so if any type is a subtype of that, set the key explicitly 

358 # so the correct loader is found, otherwise index by the type 

359 if isinstance(self.uri, io.IOBase): 

360 key = io.IOBase 

361 else: 

362 key = type(self.uri) 

363 self._loaded = HELPER_REGISTRY[key](self.uri) 

364 return self._loaded 

365 

366 def __exit__(self, type, value, traceback): 

367 self._loaded.close()