Coverage for python/lsst/pipe/base/graph/_loadHelpers.py : 34%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("LoadHelper",)

from lsst.daf.butler import ButlerURI, Quantum
from lsst.daf.butler.core._butlerUri.s3 import ButlerS3URI
from lsst.daf.butler.core._butlerUri.file import ButlerFileURI

from ..pipeline import TaskDef
from .quantumNode import NodeId

from dataclasses import dataclass
import functools
import io
import lzma
import pickle
import struct

from collections import defaultdict, UserDict
from typing import (Optional, Iterable, DefaultDict, Set, Dict, TYPE_CHECKING, Type, Union)

if TYPE_CHECKING:
    from . import QuantumGraph


# Create a custom dict that will return the desired default if a key is missing
class RegistryDict(UserDict):
    def __missing__(self, key):
        return DefaultLoadHelper


# Create a registry to hold all the load helper classes
HELPER_REGISTRY = RegistryDict()
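
# Illustrative sketch (not executed anywhere in this module): a lookup for a
# type with no registered helper falls back to the default implementation.
# Assuming nothing was ever registered for io.StringIO:
#
#     HELPER_REGISTRY[io.StringIO] is DefaultLoadHelper  # -> True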


def register_helper(URIClass: Union[Type[ButlerURI], Type[io.IO[bytes]]]):
    """Used to register classes as load helpers

    When decorating a class, the parameter is the class of the "handle type",
    i.e. a ButlerURI type or open file handle that will be used to do the
    loading. This is then associated with the decorated class such that when
    the parameter type is used to load data, the appropriate helper to work
    with that data type can be returned.

    A decorator is used so that in theory someone could define another handler
    in a different module and register it for use.

    Parameters
    ----------
    URIClass : Type of `~lsst.daf.butler.ButlerURI` or `~io.IO` of bytes
        The type to which the decorated class should be mapped.
    """
    def wrapper(class_):
        HELPER_REGISTRY[URIClass] = class_
        return class_
    return wrapper
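

# Illustrative sketch (hypothetical names, not part of this module): a helper
# for another ButlerURI subclass would be registered the same way the helpers
# defined below are, e.g.
#
#     @register_helper(SomeOtherButlerURI)
#     class SomeOtherLoadHelper(DefaultLoadHelper):
#         def _readBytes(self, start: int, stop: int) -> bytes:
#             ...  # scheme-specific read of the byte range [start, stop)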


class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns the requested nodes (or all of them). It is intended to serve
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility
    of subclasses to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.daf.butler.ButlerURI` or `io.IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save
    """
    def __init__(self, uriObject: Union[ButlerURI, io.IO[bytes]]):
        self.uriObject = uriObject

        preambleSize, taskDefSize, nodeSize = self._readSizes()

        # Record the total header size
        self.headerSize = preambleSize + taskDefSize + nodeSize

        self._readByteMappings(preambleSize, self.headerSize, taskDefSize)

    def _readSizes(self):
        # need to import here to avoid cyclic imports
        from .graph import STRUCT_FMT_STRING, MAGIC_BYTES
        # Compute the length of the first few bytes: the magic identifier
        # bytes, the 2 byte version number, and the two 8 byte numbers that
        # are the sizes of the byte maps
        magicSize = len(MAGIC_BYTES)
        fmt = STRUCT_FMT_STRING
        fmtSize = struct.calcsize(fmt)
        preambleSize = magicSize + fmtSize

        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        sizeBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError("This file does not appear to be a quantum graph save; got magic bytes "
                             f"{magic}, expected {MAGIC_BYTES}")

        # Turn the encoded bytes back into python int objects
        save_version, taskDefSize, nodeSize = struct.unpack('>HQQ', sizeBytes)

        # Store the save version, so future read code can make use of any
        # format changes to the save protocol
        self.save_version = save_version

        return preambleSize, taskDefSize, nodeSize
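
    # Rough sketch of the on-disk layout assumed by the readers in this class
    # (the '>HQQ' layout is taken from the unpack call above; STRUCT_FMT_STRING
    # is assumed to match it):
    #
    #     preamble = MAGIC_BYTES + struct.pack('>HQQ', save_version,
    #                                          taskDefSize, nodeSize)
    #     header   = preamble + pickled taskDef byte-map + pickled node byte-map
    #     payload  = lzma-compressed pickles, addressed by (start, stop) offsets
    #                stored in the byte maps, relative to the end of the header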

    def _readByteMappings(self, preambleSize, headerSize, taskDefSize):
        # Take the header size explicitly so subclasses can modify it before
        # this method is called

        # read the bytes of the taskDef map and the node map, skipping the
        # size bytes
        headerMaps = self._readBytes(preambleSize, headerSize)

        # read the map of taskDef bytes back in, skipping the size bytes
        self.taskDefMap = pickle.loads(headerMaps[:taskDefSize])

        # read back in the graph id
        self._buildId = self.taskDefMap['__GraphBuildID']

        # read the map of the node objects back in, skipping bytes
        # corresponding to the taskDef byte map
        self.map = pickle.loads(headerMaps[taskDefSize:])

    def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = None) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        nodes : `Iterable` of `int` or `None`
            The nodes to load from the graph; loads all if the value is None
            (the default)
        graphID : `str` or `None`
            If specified, this ID is verified against the loaded graph prior
            to loading any nodes. This defaults to None, in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the requested nodes is not in the
            `QuantumGraph`, or if the graphID parameter does not match the
            graph being loaded.
        """
        # need to import here to avoid cyclic imports
        from . import QuantumGraph
        if graphID is not None and self._buildId != graphID:
            raise ValueError('graphID does not match that of the graph being loaded')
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = list(self.map.keys())
            # if all nodes are to be read, force the reader from the base
            # class that will read all the bytes in one go
            _readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read, using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodes = set(nodes)
            # verify that all requested nodes are in the graph
            remainder = nodes - self.map.keys()
            if remainder:
                raise ValueError(f"Nodes {remainder} were requested, but could not be found in the "
                                 "input graph")
            _readBytes = self._readBytes
        # create a container for loaded data
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # Get the bytes to read from the map
            start, stop = self.map[node]
            start += self.headerSize
            stop += self.headerSize

            # read the specified bytes; this call may be overloaded by
            # subclasses. The bytes are compressed, so decompress them
            dump = lzma.decompress(_readBytes(start, stop))

            # reconstruct the node
            qNode = pickle.loads(dump)

            # the un-persisted node stores the name of its taskDef. If that
            # taskDef has already been loaded, attach it; if not, read it in
            # first, then attach it
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self.taskDefMap[nodeTask]
                start += self.headerSize
                stop += self.headerSize

                # load the taskDef; this method call may be overloaded by
                # subclasses. The bytes are compressed, so decompress them
                taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop)))
                loadedTaskDef[nodeTask] = taskDef
            # Explicitly override the "frozen-ness" of nodes to attach the
            # taskDef back onto the un-persisted node
            object.__setattr__(qNode, 'taskDef', loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId)
        return qGraph
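
    # Usage sketch (illustrative; the handle, node ids, and graph id are made
    # up):
    #
    #     helper = DefaultLoadHelper(uriObject)
    #     qgraph = helper.load(nodes=[0, 3, 7], graphID=expectedBuildId)
    #     helper.close()
    #
    # In practice the `LoadHelper` context manager at the bottom of this module
    # is the intended entry point; it selects the right subclass for the given
    # URI or file handle and closes it afterwards.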

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ButlerURI object

        In the base class, this will actually read all the bytes from the
        specified ButlerURI object into a buffer, and each method call will
        then return the requested byte range from that buffer. This is the
        most flexible implementation, as no special read is required, but it
        will not give any speed up for sub graph reads.
        """
        if not hasattr(self, 'buffer'):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self):
        """Cleans up an instance if needed. Base class does nothing
        """
        pass


@register_helper(ButlerS3URI)
class S3LoadHelper(DefaultLoadHelper):
    # This subclass implements partial loading of a graph using an S3 URI
    def _readBytes(self, start: int, stop: int) -> bytes:
        args = {}
        # minus 1 in the stop range, because the Range header is inclusive,
        # rather than standard python where the end point is generally
        # exclusive
        args["Range"] = f"bytes={start}-{stop-1}"
        try:
            response = self.uriObject.client.get_object(Bucket=self.uriObject.netloc,
                                                        Key=self.uriObject.relativeToPathRoot,
                                                        **args)
        except (self.uriObject.client.exceptions.NoSuchKey,
                self.uriObject.client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        body = response["Body"].read()
        response["Body"].close()
        return body
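
    # For example (illustrative values): a read of the byte range [0, 18)
    # becomes "Range: bytes=0-17", because the HTTP Range header includes its
    # end point.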


@register_helper(ButlerFileURI)
class FileLoadHelper(DefaultLoadHelper):
    # This subclass implements partial loading of a graph using a file uri
    def _readBytes(self, start: int, stop: int) -> bytes:
        if not hasattr(self, 'fileHandle'):
            self.fileHandle = open(self.uriObject.ospath, 'rb')
        self.fileHandle.seek(start)
        return self.fileHandle.read(stop-start)

    def close(self):
        if hasattr(self, 'fileHandle'):
            self.fileHandle.close()


@register_helper(io.IOBase)  # type: ignore
class OpenFileHandleHelper(DefaultLoadHelper):
    # This handler is special in that it does not get initialized with a
    # ButlerURI, but with an open file handle.

    # Most everything stays the same; the variable is even stored as
    # uriObject, because the interface needed for reading is the same.
    # Unfortunately, because we do not have Protocols yet, this can not be
    # nicely expressed with typing.

    # This helper does support partial loading

    def __init__(self, uriObject: io.IO[bytes]):
        # Explicitly annotate the type rather than inferring it from super
        self.uriObject: io.IO[bytes]
        super().__init__(uriObject)
        # This differs from the default __init__ by seeking the io object
        # back to the beginning, so that if the entire file is to be read in,
        # the file is not already in a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        self.uriObject.seek(start)
        result = self.uriObject.read(stop-start)
        return result


@dataclass
class LoadHelper:
    """This is a helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to `~lsst.daf.butler.ButlerURI`.

    This helper will raise a `ValueError` if the specified file does not
    appear to be a valid `QuantumGraph` save file.
    """
    uri: ButlerURI
    """ButlerURI object from which the `QuantumGraph` is to be loaded
    """
    def __enter__(self):
        # Only one handler is registered for anything that is an instance of
        # IOBase, so if any type is a subtype of that, set the key explicitly
        # so the correct loader is found; otherwise index by the type
        if isinstance(self.uri, io.IOBase):
            key = io.IOBase
        else:
            key = type(self.uri)
        self._loaded = HELPER_REGISTRY[key](self.uri)
        return self._loaded

    def __exit__(self, type, value, traceback):
        self._loaded.close()
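

# Usage sketch (illustrative; the path is made up): `LoadHelper` is intended to
# be used as a context manager, so that any file handle opened by the selected
# helper is closed again, e.g.
#
#     with LoadHelper(ButlerURI("/some/path/graph.qgraph")) as loader:
#         qgraph = loader.load(nodes=[0, 1])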