Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 40%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("LoadHelper",)
25import functools
26import io
27import struct
28from dataclasses import dataclass
29from types import TracebackType
30from typing import (
31 TYPE_CHECKING,
32 Any,
33 BinaryIO,
34 Callable,
35 ContextManager,
36 Iterable,
37 Optional,
38 Type,
39 TypeVar,
40 Union,
41)
42from uuid import UUID
44from lsst.daf.butler import DimensionUniverse
45from lsst.resources import ResourcePath
46from lsst.resources.file import FileResourcePath
47from lsst.resources.s3 import S3ResourcePath
49if TYPE_CHECKING: 49 ↛ 50line 49 didn't jump to line 50, because the condition on line 49 was never true
50 from ._versionDeserializers import DeserializerBase
51 from .graph import QuantumGraph
54_T = TypeVar("_T")
class RegistryDict(dict):
    """Dictionary that falls back to `DefaultLoadHelper` for unknown keys.

    Looking up a key that was never stored returns the `DefaultLoadHelper`
    class instead of raising `KeyError`. The missing key is not inserted
    into the dictionary by the lookup.
    """

    def __missing__(self, key: Any) -> Type[DefaultLoadHelper]:
        # ``dict`` invokes this hook only for absent keys; the key value is
        # irrelevant because every unknown key gets the same fallback.
        return DefaultLoadHelper
# Module-level registry associating a "handle type" with the load helper
# class used to read from it. Entries are added by ``register_helper``.
HELPER_REGISTRY: RegistryDict = RegistryDict()
def register_helper(URIClass: Union[Type[ResourcePath], Type[BinaryIO]]) -> Callable[[_T], _T]:
    """Build a class decorator that registers a load helper.

    The argument is the "handle type", i.e. a ResourcePath type or open file
    handle type, that will be used to do the loading. The decorated class is
    stored in ``HELPER_REGISTRY`` under that type, so that when an object of
    the handle type is used to load data the matching helper can be looked
    up.

    A decorator is used so that, in theory, another module could define its
    own handler and register it for use.

    Parameters
    ----------
    URIClass : Type of `~lsst.resources.ResourcePath` or `~IO` of bytes
        Type for which the decorated class should be mapped to.

    Returns
    -------
    decorator : `Callable`
        Decorator that records the class it is applied to and returns that
        class unchanged.
    """

    def _bind(helper_cls: _T) -> _T:
        # Record the association, then hand the class back untouched so the
        # decorated name still refers to the original class.
        HELPER_REGISTRY[URIClass] = helper_cls
        return helper_cls

    return _bind
class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files.

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility
    of subclasses to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter forces a user to
        interact manually and take this into account before they can be used
        in production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the graph save version is below the
        minimum specified version.
    RuntimeError
        Raised if the save version of the file is newer than the newest
        version this code knows how to read.
    """

    def __init__(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except BaseException:
            # Construction failed, so no caller will ever hold this instance
            # and be able to close it. Release whatever the reader may have
            # acquired while reading the preamble before propagating.
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int) -> bytes:
        """Validate the save file preamble and select a deserializer.

        Stores ``uriObject`` and the chosen ``deserializer`` on the instance.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object the raw bytes are read from.
        minimumVersion : `int`
            Lowest save version that is accepted.

        Returns
        -------
        headerBytes : `bytes`
            The bytes of the header section of the save file.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match, or the save version is
            below ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than ``SAVE_VERSION``.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            # Adjacent literals are concatenated verbatim, so each fragment
            # must carry its own trailing space.
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system, so we just tell MyPy to ignore it.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)  # type: ignore

        # get the header info; ``headerSize`` is passed as the stop offset of
        # the read
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int = 3) -> Optional[str]:
        """Return the unpacked header of a save file without loading a graph.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object the raw bytes are read from.
        minimumVersion : `int`, optional
            Lowest save version that is accepted, defaults to 3.

        Returns
        -------
        header : `str` or `None`
            Whatever the deserializer's ``unpackHeader`` produces for the
            header bytes.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match, or the save version is
            below ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than this code can read.
        """
        # Bypass __init__ so that the header info is not deserialized; only
        # the preamble validation and raw header bytes are needed here.
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # Always release reader resources, including on failure.
            instance.close()

    def load(
        self,
        universe: DimensionUniverse,
        nodes: Optional[Iterable[UUID]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Load the specified nodes from the graph.

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe : `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes : `Iterable` of `UUID` or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). String entries are converted to `UUID`.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object.

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = set(self.headerInfo.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes: Callable[[int, int], bytes] = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodes = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodes - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        return self.deserializer.constructGraph(nodes, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Load the specified byte range from the ResourcePath object.

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The slice ``[start:stop]`` of the full contents.
        """
        # The full contents are fetched once and cached on the instance.
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self) -> None:
        """Clean up an instance if needed. Base class does nothing."""
        pass
@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph using an S3 URI.

    Each read issues a ``get_object`` call carrying a ``Range`` argument for
    the requested bytes.
    """

    uriObject: S3ResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Read the byte range ``[start, stop)`` from the S3 object.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The bytes read from the response body, or ``b""`` when the
            requested range is empty.

        Raises
        ------
        FileNotFoundError
            Raised if the client reports that the key or bucket does not
            exist.
        """
        # An empty request would produce an inverted range such as
        # "bytes=5-4", which is not a valid HTTP byte range. Return no data
        # without contacting the service, matching what slicing a buffer
        # with the same bounds yields.
        if stop <= start:
            return b""
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop - 1}"
        try:
            response = self.uriObject.client.get_object(
                Bucket=self.uriObject.netloc,
                Key=self.uriObject.relativeToPathRoot,
                Range=byteRange,
            )
        except (
            self.uriObject.client.exceptions.NoSuchKey,
            self.uriObject.client.exceptions.NoSuchBucket,
        ) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        body = response["Body"]
        try:
            return body.read()
        finally:
            # Close the response body even when the read itself fails.
            body.close()
@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper implementing partial loading of a graph using a file URI.

    The file at ``uriObject.ospath`` is opened on the first read and the
    handle is reused for later reads until `close` is called.
    """

    uriObject: FileResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Seek to ``start`` and read ``stop - start`` bytes from the file."""
        handle = getattr(self, "fileHandle", None)
        if handle is None:
            # First read on this instance: open the file lazily and keep the
            # handle around for subsequent reads.
            handle = self.fileHandle = open(self.uriObject.ospath, "rb")
        handle.seek(start)
        return handle.read(stop - start)

    def close(self) -> None:
        """Close the file handle if one was ever opened."""
        handle = getattr(self, "fileHandle", None)
        if handle is not None:
            handle.close()
@register_helper(BinaryIO)
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper that reads from an already open binary file handle.

    This handler is special in that it is not initialized with a
    ResourcePath but with an open file handle. Almost everything else stays
    the same; the handle is even stored as ``uriObject`` because the
    interface needed for reading is the same. Without Protocols this cannot
    be expressed nicely with typing.

    This helper does support partial loading.
    """

    def __init__(self, uriObject: BinaryIO, minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: BinaryIO
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # Unlike the default __init__, rewind the io object so that if the
        # entire file is read afterwards it does not start from a partially
        # read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Seek to ``start`` and read ``stop - start`` bytes from the handle."""
        handle = self.uriObject
        handle.seek(start)
        return handle.read(stop - start)
@dataclass
class LoadHelper(ContextManager[DefaultLoadHelper]):
    """Context manager that picks the right loader for a graph save.

    On entry the helper class registered for the type of ``uri`` is
    instantiated and returned; on exit that instance is closed.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: Union[ResourcePath, BinaryIO]
    """ResourcePath object from which the `QuantumGraph` is to be loaded
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __enter__(self) -> DefaultLoadHelper:
        # Instantiate the helper selected for this kind of handle and keep a
        # reference so that __exit__ can close it.
        helperClass = self._determineLoader()
        self._loaded = helperClass(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Exceptions are not suppressed; only the loader is cleaned up.
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Return the helper class registered for the type of ``uri``."""
        # Only one handler is registered for anything that is an instance of
        # IOBase, so every such object is looked up under BinaryIO; anything
        # else is indexed by its own type.
        #
        # Typing for file-like types is a mess; BinaryIO isn't actually
        # a base class of what open(..., 'rb') returns, and MyPy claims
        # that IOBase and BinaryIO actually have incompatible method
        # signatures. IOBase *is* a base class of what open(..., 'rb')
        # returns, so it's what we have to use at runtime.
        if isinstance(self.uri, io.IOBase):  # type: ignore
            return HELPER_REGISTRY[BinaryIO]
        return HELPER_REGISTRY[type(self.uri)]

    def readHeader(self) -> Optional[str]:
        """Return the header of the save via the selected helper class.

        Delegates to the ``dumpHeader`` classmethod of the helper class
        chosen for ``uri``, passing ``uri`` and ``minimumVersion``.
        """
        return self._determineLoader().dumpHeader(self.uri, self.minimumVersion)