Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 30%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("LoadHelper", )
25from lsst.resources import ResourcePath
26from lsst.daf.butler import Quantum
27from lsst.resources.s3 import S3ResourcePath
28from lsst.resources.file import FileResourcePath
30from ..pipeline import TaskDef
31from .quantumNode import NodeId
33from dataclasses import dataclass
34import functools
35import io
36import json
37import lzma
38import pickle
39import struct
41from collections import defaultdict, UserDict
42from typing import (Optional, Iterable, DefaultDict, Set, Dict, TYPE_CHECKING, Tuple, Type, Union)
44if TYPE_CHECKING: 44 ↛ 45line 44 didn't jump to line 45, because the condition on line 44 was never true
45 from . import QuantumGraph
class RegistryDict(UserDict):
    """Dictionary of load helper classes with a built-in fallback.

    Looking up a key that was never stored returns `DefaultLoadHelper`
    rather than raising `KeyError`.
    """

    def __missing__(self, key):
        # Any key without a dedicated entry is served by the generic helper.
        return DefaultLoadHelper


# Registry holding all the load helper classes, keyed by the handle type
# each one was registered for.
HELPER_REGISTRY = RegistryDict()
def register_helper(URICLass: Union[Type[ResourcePath], Type[io.IO[bytes]]]):
    """Build a class decorator that registers a class as a load helper.

    The argument is the class of the "handle type", i.e. a ResourcePath type
    or open file handle type, that will be used to do the loading. The
    decorated class is stored in `HELPER_REGISTRY` under that type, so that
    a lookup with the handle type returns the helper able to work with it.

    A decorator is used so that in theory someone could define another
    handler in a different module and register it for use.

    Parameters
    ----------
    URICLass : Type of `~lsst.resources.ResourcePath` or `~io.IO` of bytes
        Type for which the decorated class should be mapped to.

    Returns
    -------
    decorator : callable
        Function that takes a class, records it in `HELPER_REGISTRY` under
        ``URICLass`` and returns the class unchanged.
    """
    # NOTE(review): the ``io`` module has no runtime attribute ``IO`` (that
    # name lives in ``typing``); the annotation above is presumably only safe
    # because annotations are not evaluated in this module -- confirm.
    def _register(helperClass):
        HELPER_REGISTRY[URICLass] = helperClass
        return helperClass

    return _register
class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility
    of sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `io.IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the save version is not one this
        class knows how to unpack.
    RuntimeError
        Raised if the save file declares a version newer than the newest
        version this code knows how to read.
    """
    def __init__(self, uriObject: Union[ResourcePath, io.IO[bytes]]):
        self.uriObject = uriObject

        # infoSize is either a tuple of length 2 (version 1), holding the
        # lengths of 2 independent pickles, or a tuple of length 1 holding
        # the total length of the entire header information (minus the magic
        # bytes and version bytes)
        preambleSize, infoSize = self._readSizes()

        # Record the total header size
        if self.save_version == 1:
            self.headerSize = preambleSize + infoSize[0] + infoSize[1]
        elif self.save_version == 2:
            self.headerSize = preambleSize + infoSize[0]
        else:
            raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
                             "please try a newer version of the code.")

        self._readByteMappings(preambleSize, self.headerSize, infoSize)

    def _readSizes(self) -> Tuple[int, Tuple[int, ...]]:
        """Read the magic bytes, save version and header size information.

        Sets ``self.save_version`` as a side effect.

        Returns
        -------
        preambleSize : `int`
            Number of bytes taken up by the magic bytes, the version number
            and the size information that follows it.
        infoUnpack : `tuple` of `int`
            The unpacked size information for this save version.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match.
        RuntimeError
            Raised if the save version is newer than this code supports.
        """
        # need to import here to avoid cyclic imports
        from .graph import STRUCT_FMT_BASE, MAGIC_BYTES, STRUCT_FMT_STRING, SAVE_VERSION
        # The file starts with the magic identifier bytes followed by the
        # version number; the sizes of the byte maps come right after.
        magicSize = len(MAGIC_BYTES)

        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize

        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError("This file does not appear to be a quantum graph save got magic bytes "
                             f"{magic}, expected {MAGIC_BYTES}")

        # Turn the encoded bytes back into a python int object
        save_version, = struct.unpack(STRUCT_FMT_BASE, versionBytes)

        if save_version > SAVE_VERSION:
            raise RuntimeError(f"The version of this save file is {save_version}, but this version of "
                               f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}")

        # read in the size information, whose layout depends on the version
        fmtString = STRUCT_FMT_STRING[save_version]
        infoSize = struct.calcsize(fmtString)
        infoBytes = self._readBytes(preambleSize, preambleSize+infoSize)
        infoUnpack = struct.unpack(fmtString, infoBytes)

        preambleSize += infoSize

        # Store the save version, so future read codes can make use of any
        # format changes to the save protocol
        self.save_version = save_version

        return preambleSize, infoUnpack

    def _readByteMappings(self, preambleSize: int, headerSize: int, infoSize: Tuple[int, ...]) -> None:
        """Read the header maps that locate taskDefs and nodes in the file.

        Sets ``self.taskDefMap``, ``self._buildId``, ``self.map`` and
        ``self.metadata``.

        Parameters
        ----------
        preambleSize : `int`
            Offset at which the header maps start.
        headerSize : `int`
            Offset at which the header maps end. Taken explicitly so
            subclasses can modify it before this method is called.
        infoSize : `tuple` of `int`
            Size information as returned by ``_readSizes``.

        Raises
        ------
        ValueError
            Raised if ``self.save_version`` is not a supported version.
        """
        # read the bytes of taskDef bytes and nodes skipping the size bytes
        headerMaps = self._readBytes(preambleSize, headerSize)

        if self.save_version == 1:
            taskDefSize, _ = infoSize

            # NOTE: pickle.loads can execute arbitrary code; only load save
            # files that come from a trusted source.
            # read the map of taskDef bytes back in skipping the size bytes
            self.taskDefMap = pickle.loads(headerMaps[:taskDefSize])

            # read back in the graph id
            self._buildId = self.taskDefMap['__GraphBuildID']

            # read the map of the node objects back in skipping bytes
            # corresponding to the taskDef byte map
            self.map = pickle.loads(headerMaps[taskDefSize:])

            # There is no metadata for old versions
            self.metadata = None
        elif self.save_version == 2:
            uncompressedHeaderMap = lzma.decompress(headerMaps)
            header = json.loads(uncompressedHeaderMap)
            self.taskDefMap = header['TaskDefs']
            self._buildId = header['GraphBuildID']
            self.map = dict(header['Nodes'])
            self.metadata = header['Metadata']
        else:
            raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
                             "please try a newer version of the code.")

    def _byteRange(self, entry) -> Tuple[int, int]:
        """Convert a byte map entry into absolute offsets within the file.

        Version 1 entries are the ``(start, stop)`` pair itself, later
        versions keep the pair under the ``'bytes'`` key. The stored offsets
        are relative to the end of the header.
        """
        start, stop = entry if self.save_version == 1 else entry['bytes']
        return start + self.headerSize, stop + self.headerSize

    def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = None) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        nodes : `Iterable` of `int` or `None`
            The nodes to load from the graph, loads all if value is None
            (the default)
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        if graphID is not None and self._buildId != graphID:
            raise ValueError('graphID does not match that of the graph being loaded')
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = list(self.map)
            # if all nodes are to be read, force the reader from the base
            # class that will read all the bytes in one go
            readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodes = set(nodes)
            # verify that all nodes requested are in the graph
            remainder = nodes - self.map.keys()
            if remainder:
                raise ValueError(f"Nodes {remainder} were requested, but could not be found in the input "
                                 "graph")
            readBytes = self._readBytes
        # create a container for loaded data
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # Get the bytes to read from the map
            start, stop = self._byteRange(self.map[node])

            # read the specified bytes, will be overloaded by subclasses
            # bytes are compressed, so decompress them
            dump = lzma.decompress(readBytes(start, stop))

            # reconstruct node
            # NOTE: pickle.loads can execute arbitrary code; only load save
            # files that come from a trusted source.
            qNode = pickle.loads(dump)

            # read the saved node, name. If it has been loaded, attach it, if
            # not read in the taskDef first, and then load it
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self._byteRange(self.taskDefMap[nodeTask])

                # load the taskDef, this method call will be overloaded by
                # subclasses.
                # bytes are compressed, so decompress them
                loadedTaskDef[nodeTask] = pickle.loads(lzma.decompress(readBytes(start, stop)))
            # Explicitly overload the "frozen-ness" of nodes to attach the
            # taskDef back into the un-persisted node
            object.__setattr__(qNode, 'taskDef', loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # need to import here to avoid cyclic imports; deferred until the
        # arguments have been validated and the class is actually needed
        from . import QuantumGraph

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId,
                            metadata=self.metadata)
        return qGraph

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The bytes in the half-open range ``[start, stop)``.
        """
        if not hasattr(self, 'buffer'):
            # cache the whole payload on first use
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self):
        """Cleans up an instance if needed. Base class does nothing
        """
        pass
@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph using a s3 uri.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Fetch the half-open byte range ``[start, stop)`` from the object.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The requested bytes.

        Raises
        ------
        FileNotFoundError
            Raised if the client reports that the key or the bucket does not
            exist.
        """
        # An empty range cannot be written as an inclusive Range header, so
        # answer it locally the same way a slice would.
        if stop <= start:
            return b""
        client = self.uriObject.client
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop-1}"
        try:
            response = client.get_object(Bucket=self.uriObject.netloc,
                                         Key=self.uriObject.relativeToPathRoot,
                                         Range=byteRange)
        except (client.exceptions.NoSuchKey, client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        body = response["Body"]
        try:
            return body.read()
        finally:
            # release the response body even when the read fails
            body.close()
@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph using a file uri.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` read from the local file.

        The file is opened in binary mode on the first call and the handle is
        kept on the instance for later calls.
        """
        handle = getattr(self, 'fileHandle', None)
        if handle is None:
            handle = self.fileHandle = open(self.uriObject.ospath, 'rb')
        handle.seek(start)
        return handle.read(stop - start)

    def close(self):
        """Close the file handle if one was opened.
        """
        handle = getattr(self, 'fileHandle', None)
        if handle is not None:
            handle.close()
@register_helper(io.IOBase)  # type: ignore
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper that reads from an already open file handle.

    Unlike the other helpers this one is not initialized with a ResourcePath.
    The handle is still stored as ``uriObject`` because the interface needed
    for reading is the same. Unfortunately because we do not have Protocols
    yet, this can not be nicely expressed with typing.

    This helper does support partial loading.
    """

    def __init__(self, uriObject: io.IO[bytes]):
        # Explicitly annotate type and not infer from super
        self.uriObject: io.IO[bytes]
        super().__init__(uriObject)
        # Unlike the default __init__, rewind the io object afterwards so
        # that, should the entire file be read later, it does not start from
        # a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Seek to ``start`` and read ``stop - start`` bytes from the handle.
        """
        stream = self.uriObject
        stream.seek(start)
        return stream.read(stop - start)
@dataclass
class LoadHelper:
    """Context manager that selects the appropriate loader for a handle and
    closes it again on exit.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """
    uri: ResourcePath
    """ResourcePath object from which the `QuantumGraph` is to be loaded
    """

    def __enter__(self):
        # A single handler is registered for everything that is an instance
        # of IOBase, so such objects are looked up under that key; any other
        # object is looked up by its exact type.
        key = io.IOBase if isinstance(self.uri, io.IOBase) else type(self.uri)
        helperClass = HELPER_REGISTRY[key]
        self._loaded = helperClass(self.uri)
        return self._loaded

    def __exit__(self, type, value, traceback):
        # Give the helper the chance to release whatever it opened.
        self._loaded.close()