Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 29%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("LoadHelper", )
25from lsst.daf.butler import ButlerURI, Quantum
26from lsst.daf.butler.core._butlerUri.s3 import ButlerS3URI
27from lsst.daf.butler.core._butlerUri.file import ButlerFileURI
29from ..pipeline import TaskDef
30from .quantumNode import NodeId
32from dataclasses import dataclass
33import functools
34import io
35import json
36import lzma
37import pickle
38import struct
40from collections import defaultdict, UserDict
41from typing import (Optional, Iterable, DefaultDict, Set, Dict, TYPE_CHECKING, Tuple, Type, Union)
if TYPE_CHECKING:
44 from . import QuantumGraph
# Custom mapping: lookups for unregistered keys fall back to the default
# helper instead of raising KeyError
class RegistryDict(UserDict):
    """Mapping of handle types to load-helper classes.

    Any key that has not been explicitly registered resolves to
    `DefaultLoadHelper`, which can read from anything with a ``read``
    method.
    """
    def __missing__(self, key):
        # Invoked by UserDict when ``key`` is absent; supply the generic
        # whole-file helper as the fallback.
        return DefaultLoadHelper
# Create a registry to hold all the load Helper classes.
# Keys are handle types (a ButlerURI subclass or io.IOBase); values are the
# helper classes registered via the register_helper decorator. Unregistered
# types fall back to DefaultLoadHelper via RegistryDict.__missing__.
HELPER_REGISTRY = RegistryDict()
def register_helper(URIClass: Union[Type[ButlerURI], Type[io.IO[bytes]]]):
    """Used to register classes as Load helpers

    When decorating a class the parameter is the class of "handle type", i.e.
    a ButlerURI type or open file handle that will be used to do the loading.
    This is then associated with the decorated class such that when the
    parameter type is used to load data, the appropriate helper to work with
    that data type can be returned.

    A decorator is used so that in theory someone could define another handler
    in a different module and register it for use.

    Parameters
    ----------
    URIClass : Type of `~lsst.daf.butler.ButlerURI` or `~io.IO` of bytes
        type for which the decorated class should be mapped to
    """
    # NOTE(review): parameter renamed from the misspelled ``URICLass`` to
    # match the documented name; decorator is only ever applied positionally.
    def wrapper(class_):
        HELPER_REGISTRY[URIClass] = class_
        return class_
    return wrapper
class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.daf.butler.ButlerURI` or `io.IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save
    """
    def __init__(self, uriObject: Union[ButlerURI, io.IO[bytes]]):
        self.uriObject = uriObject

        # infoSize will either be a tuple with length 2 (version 1), which
        # contains the lengths of 2 independent pickles, or a tuple of
        # length 1 (version 2) which contains the total length of the entire
        # header information (minus the magic bytes and version bytes)
        preambleSize, infoSize = self._readSizes()

        # Record the total header size
        if self.save_version == 1:
            self.headerSize = preambleSize + infoSize[0] + infoSize[1]
        elif self.save_version == 2:
            self.headerSize = preambleSize + infoSize[0]
        else:
            raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
                             "please try a newer version of the code.")

        self._readByteMappings(preambleSize, self.headerSize, infoSize)

    def _readSizes(self) -> Tuple[int, Tuple[int, ...]]:
        """Read the file preamble: magic bytes, save version, and the sizes
        of the header byte maps.

        Returns
        -------
        preambleSize : `int`
            Total number of bytes consumed by the preamble (including the
            size fields themselves).
        infoUnpack : `tuple` of `int`
            The unpacked header-size fields; length depends on save version.
        """
        # need to import here to avoid cyclic imports
        from .graph import STRUCT_FMT_BASE, MAGIC_BYTES, STRUCT_FMT_STRING, SAVE_VERSION
        # Read the first few bytes which correspond to the lengths of the
        # magic identifier bytes, 2 byte version number and the two 8 byte
        # numbers that are the sizes of the byte maps
        magicSize = len(MAGIC_BYTES)

        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize

        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError("This file does not appear to be a quantum graph save got magic bytes "
                             f"{magic}, expected {MAGIC_BYTES}")

        # Turn the encoded bytes back into a python int object
        save_version, = struct.unpack(STRUCT_FMT_BASE, versionBytes)

        # Bug fix: original adjacent string literals were missing a space,
        # producing "...this version ofQuantum Graph..."
        if save_version > SAVE_VERSION:
            raise RuntimeError(f"The version of this save file is {save_version}, but this version of "
                               f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}")

        # read in the next bits
        fmtString = STRUCT_FMT_STRING[save_version]
        infoSize = struct.calcsize(fmtString)
        infoBytes = self._readBytes(preambleSize, preambleSize+infoSize)
        infoUnpack = struct.unpack(fmtString, infoBytes)

        # The size fields count as part of the preamble from the point of
        # view of callers
        preambleSize += infoSize

        # Store the save version, so future read codes can make use of any
        # format changes to the save protocol
        self.save_version = save_version

        return preambleSize, infoUnpack

    def _readByteMappings(self, preambleSize: int, headerSize: int, infoSize: Tuple[int, ...]) -> None:
        """Read and decode the header byte maps that locate each taskDef and
        node within the file.

        Take the header size explicitly so subclasses can modify it before
        this method is called.
        """
        # read the bytes of taskDef bytes and nodes skipping the size bytes
        headerMaps = self._readBytes(preambleSize, headerSize)

        if self.save_version == 1:
            taskDefSize, _ = infoSize

            # read the map of taskDef bytes back in skipping the size bytes
            self.taskDefMap = pickle.loads(headerMaps[:taskDefSize])

            # read back in the graph id
            self._buildId = self.taskDefMap['__GraphBuildID']

            # read the map of the node objects back in skipping bytes
            # corresponding to the taskDef byte map
            self.map = pickle.loads(headerMaps[taskDefSize:])

            # There is no metadata for old versions
            self.metadata = None
        elif self.save_version == 2:
            # Version 2 headers are lzma-compressed json
            uncompressedHeaderMap = lzma.decompress(headerMaps)
            header = json.loads(uncompressedHeaderMap)
            self.taskDefMap = header['TaskDefs']
            self._buildId = header['GraphBuildID']
            self.map = dict(header['Nodes'])
            self.metadata = header['Metadata']
        else:
            raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
                             "please try a newer version of the code.")

    def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = None) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        nodes : `Iterable` of `int` or `None`
            The nodes to load from the graph, loads all if value is None
            (the default)
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        # need to import here to avoid cyclic imports
        from . import QuantumGraph
        if graphID is not None and self._buildId != graphID:
            raise ValueError('graphID does not match that of the graph being loaded')
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = list(self.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodes = set(nodes)
            # verify that all nodes requested are in the graph
            remainder = nodes - self.map.keys()
            if remainder:
                # Bug fix: the original message was missing the f prefix, so
                # {remainder} was never interpolated
                raise ValueError(f"Nodes {remainder} were requested, but could not be found in the "
                                 "input graph")
            _readBytes = self._readBytes
        # create a container for loaded data
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # Get the bytes to read from the map
            if self.save_version == 1:
                start, stop = self.map[node]
            else:
                start, stop = self.map[node]['bytes']
            # byte ranges in the map are relative to the end of the header
            start += self.headerSize
            stop += self.headerSize

            # read the specified bytes, will be overloaded by subclasses
            # bytes are compressed, so decompress them
            dump = lzma.decompress(_readBytes(start, stop))

            # reconstruct node
            qNode = pickle.loads(dump)

            # read the saved node, name. If it has been loaded, attach it, if
            # not read in the taskDef first, and then load it
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                if self.save_version == 1:
                    start, stop = self.taskDefMap[nodeTask]
                else:
                    start, stop = self.taskDefMap[nodeTask]['bytes']
                start += self.headerSize
                stop += self.headerSize

                # load the taskDef, this method call will be overloaded by
                # subclasses.
                # bytes are compressed, so decompress them
                taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop)))
                loadedTaskDef[nodeTask] = taskDef
            # Explicitly overload the "frozen-ness" of nodes to attach the
            # taskDef back into the un-persisted node
            object.__setattr__(qNode, 'taskDef', loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId,
                            metadata=self.metadata)
        return qGraph

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ButlerURI object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ButlerURI object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.
        """
        # Read the whole stream once, lazily, and cache it for later calls
        if not hasattr(self, 'buffer'):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self):
        """Cleans up an instance if needed. Base class does nothing
        """
        pass
@register_helper(ButlerS3URI)
class S3LoadHelper(DefaultLoadHelper):
    # Implements partial loading of a graph stored at an s3 uri by fetching
    # only the requested byte range from the object store.
    def _readBytes(self, start: int, stop: int) -> bytes:
        # HTTP range headers are inclusive at both ends, unlike standard
        # python slices where the end point is exclusive, hence stop-1
        byteRange = f"bytes={start}-{stop-1}"
        client = self.uriObject.client
        try:
            response = client.get_object(Bucket=self.uriObject.netloc,
                                         Key=self.uriObject.relativeToPathRoot,
                                         Range=byteRange)
        except (client.exceptions.NoSuchKey,
                client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        data = response["Body"].read()
        response["Body"].close()
        return data
@register_helper(ButlerFileURI)
class FileLoadHelper(DefaultLoadHelper):
    # Implements partial loading of a graph stored at a file uri by seeking
    # within a lazily opened local file.
    def _readBytes(self, start: int, stop: int) -> bytes:
        # Open on first use and keep the handle so repeated range reads do
        # not reopen the file.
        if not hasattr(self, 'fileHandle'):
            self.fileHandle = open(self.uriObject.ospath, 'rb')
        handle = self.fileHandle
        handle.seek(start)
        return handle.read(stop - start)

    def close(self):
        # Only close if _readBytes ever opened the file
        if hasattr(self, 'fileHandle'):
            self.fileHandle.close()
@register_helper(io.IOBase)  # type: ignore
class OpenFileHandleHelper(DefaultLoadHelper):
    # Unlike the other helpers, this one is initialized with an already open
    # file handle rather than a ButlerURI.
    #
    # Almost everything carries over unchanged; the handle is still stored as
    # uriObject because the read interface needed is identical. Without
    # Protocols available this cannot be expressed cleanly in typing.
    #
    # This helper does support partial loading.

    def __init__(self, uriObject: io.IO[bytes]):
        # Explicitly annotate type and not infer from super
        self.uriObject: io.IO[bytes]
        super().__init__(uriObject)
        # Rewind after the header reads in __init__ so that a subsequent
        # whole-file read does not start from a partially consumed stream.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        self.uriObject.seek(start)
        return self.uriObject.read(stop - start)
@dataclass
class LoadHelper:
    """This is a helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to `~lsst.daf.butler.ButlerURI`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """
    uri: ButlerURI
    """ButlerURI object from which the `QuantumGraph` is to be loaded
    """
    def __enter__(self):
        # Only one handler is registered for anything that is an instance of
        # IOBase, so if any type is a subtype of that, set the key explicitly
        # so the correct loader is found, otherwise index by the type
        if isinstance(self.uri, io.IOBase):
            key = io.IOBase
        else:
            key = type(self.uri)
        self._loaded = HELPER_REGISTRY[key](self.uri)
        return self._loaded

    # Parameters renamed from (type, value, traceback) so the builtin ``type``
    # is not shadowed; the context-manager protocol calls positionally.
    def __exit__(self, exc_type, exc_value, traceback):
        # Delegate cleanup (e.g. closing file handles) to the helper created
        # in __enter__
        self._loaded.close()