Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 40%
130 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-20 02:51 -0700
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-20 02:51 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("LoadHelper",)
25import functools
26import io
27import struct
28from dataclasses import dataclass
29from types import TracebackType
30from typing import (
31 TYPE_CHECKING,
32 Any,
33 BinaryIO,
34 Callable,
35 ContextManager,
36 Iterable,
37 Optional,
38 Set,
39 Type,
40 TypeVar,
41 Union,
42)
43from uuid import UUID
45from lsst.daf.butler import DimensionUniverse
46from lsst.resources import ResourcePath
47from lsst.resources.file import FileResourcePath
48from lsst.resources.s3 import S3ResourcePath
50if TYPE_CHECKING: 50 ↛ 51line 50 didn't jump to line 51, because the condition on line 50 was never true
51 from ._versionDeserializers import DeserializerBase
52 from .graph import QuantumGraph
_T = TypeVar("_T")


class RegistryDict(dict):
    """Mapping from handle types to load-helper classes.

    Looking up a handle type that has never been registered does not raise
    `KeyError`; it yields `DefaultLoadHelper` instead.
    """

    def __missing__(self, key: Any) -> Type[DefaultLoadHelper]:
        # ``dict.__getitem__`` invokes this hook on a miss. Fall back to the
        # generic helper, which reads the whole stream rather than byte ranges.
        return DefaultLoadHelper


# Registry that the ``register_helper`` decorator populates.
HELPER_REGISTRY: RegistryDict = RegistryDict()
def register_helper(URIClass: Union[Type[ResourcePath], Type[BinaryIO]]) -> Callable[[_T], _T]:
    """Build a class decorator that registers a load helper.

    The argument is the "handle type" the decorated helper knows how to read
    from: either a `~lsst.resources.ResourcePath` subclass or the type used to
    mark open binary file handles. Decorating a class stores it in
    ``HELPER_REGISTRY`` under that type. Later loads can then look up the
    right helper from the type of the object being loaded.

    A decorator is used so that, in principle, a helper defined in a
    different module can register itself for use.

    Parameters
    ----------
    URIClass : type of `~lsst.resources.ResourcePath` or `BinaryIO`
        Handle type that the decorated class should be mapped to.

    Returns
    -------
    decorator : `Callable`
        Decorator that records the class in ``HELPER_REGISTRY`` and returns
        it unchanged.
    """

    def _register(helperClass: _T) -> _T:
        # Record the mapping, then hand the class back untouched so the
        # decorator is transparent to the class definition.
        HELPER_REGISTRY[URIClass] = helperClass
        return helperClass

    return _register
class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files.

    This class and its subclasses unpack a quantum graph save file. The file
    is a binary representation of the graph, laid out so that individual
    nodes can be loaded without loading the entire file.

    This default implementation has the interface for loading selected
    nodes, but it always reads the entire save file and then returns only
    the requested nodes (or all of them). It is meant for any input that has
    a ``read`` method but no known way to read selected byte ranges.
    Subclasses are responsible for overriding `_readBytes` to read individual
    byte ranges from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
        The object used to retrieve the raw bytes of the save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded and re-saved to
        upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this, the minimumVersion parameter forces a user to interact
        manually and take this into account before older graphs can be used
        in production.

    Raises
    ------
    ValueError
        Raised if the specified file has the wrong file signature (so is not
        a `QuantumGraph` save), is truncated before the save version, or has
        a save version below the minimum specified version.
    RuntimeError
        Raised if the save version is newer than this code knows how to read.
    """

    def __init__(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except BaseException:
            # Reading the preamble may already have acquired resources through
            # _readBytes (e.g. FileLoadHelper opens a file handle). A caller
            # never receives an instance whose construction failed, so release
            # those resources here instead of leaking them.
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int) -> bytes:
        """Validate the save preamble, pick a deserializer, and read the header.

        Sets ``self.uriObject`` and ``self.deserializer`` as side effects.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Source of the raw save bytes.
        minimumVersion : `int`
            Minimum acceptable save version.

        Returns
        -------
        headerBytes : `bytes`
            Raw bytes of the header info section, ready to be handed to the
            deserializer.

        Raises
        ------
        ValueError
            Raised if the magic bytes are wrong, the save version bytes are
            truncated, or the save version is below ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than ``SAVE_VERSION``.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # Read the first few bytes, which hold the magic identifier bytes
        # followed by the save version.
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # A file that ends inside the version field would otherwise surface as
        # an opaque struct.error from struct.unpack below.
        if len(versionBytes) != fmtSize:
            raise ValueError(
                "This file appears to be a truncated quantum graph save: expected "
                f"{fmtSize} save version bytes after the magic bytes, got {len(versionBytes)}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # Loads can sometimes trigger an upgrade of the format to the latest
        # version, in which case accessory code might not match the upgraded
        # graph (e.g. switching from old node numbers to UUIDs). This clause
        # requires users to interact explicitly with older graph versions and
        # verify that everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # Read the bytes that describe the byte boundaries of the header info
        # that follows, and initialize the deserializer with them.
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system, so we just tell MyPy to ignore it.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)  # type: ignore

        # Get the header info. _readBytes takes absolute (start, stop) offsets,
        # so ``headerSize`` is used here as the absolute end of the header.
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int = 3) -> Optional[str]:
        """Return the unpacked header of a save without loading the graph.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Source of the raw save bytes.
        minimumVersion : `int`, optional
            Minimum version of a save file to accept (default 3).

        Returns
        -------
        header : `str` or `None`
            The value returned by the version deserializer's
            ``unpackHeader``.

        Raises
        ------
        ValueError
            Raised under the same conditions as the constructor.
        RuntimeError
            Raised if the save version is newer than this code can read.
        """
        # Bypass __init__ so only the header is processed; readHeaderInfo is
        # not needed to dump it.
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # Release any handle opened while reading, even on failure.
            instance.close()

    def load(
        self,
        universe: Optional[DimensionUniverse] = None,
        nodes: Optional[Iterable[Union[UUID, str]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Load in the specified nodes from the graph.

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is `None` (the default), the whole graph is loaded.

        Parameters
        ----------
        universe : `~lsst.daf.butler.DimensionUniverse` or `None`
            DimensionUniverse instance. It is not used by the method itself
            but is needed to ensure that registry data structures are
            initialized. The universe saved with the graph is used; if one is
            passed, it is used to validate compatibility with the loaded
            graph's universe.
        nodes : `Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph. All nodes are loaded if the
            value is `None` (the default).
        graphID : `str` or `None`
            If specified, this ID is verified against the loaded graph before
            any nodes are loaded. The default is `None`, in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object.

        Raises
        ------
        ValueError
            Raised if one or more of the requested nodes is not in the
            `QuantumGraph`, or if the graphID parameter does not match the
            graph being loaded.
        RuntimeError
            Raised if the supplied DimensionUniverse is not compatible with
            the DimensionUniverse saved in the graph.
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: Set[UUID]
        if nodes is None:
            nodeSet = set(self.headerInfo.map.keys())
            # If all nodes are to be read, force the reader from the base
            # class, which reads all the bytes in one go.
            _readBytes: Callable[[int, int], bytes] = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # Only some bytes are being read, so use the reader specialized
            # for this class. Build a set so each node is loaded only once.
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        return self.deserializer.constructGraph(nodeSet, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the byte range ``[start, stop)`` of the save.

        In the base class, the first call reads all the bytes from
        ``uriObject`` into ``self.buffer``. Every call then returns a slice of
        that buffer. This is the most flexible implementation, because no
        special read is required, but it gives no speed-up for sub-graph
        reads.
        """
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self) -> None:
        """Clean up an instance if needed. The base class does nothing."""
        pass
@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper that reads byte ranges of a graph save directly from S3.

    Each `_readBytes` call issues a ranged ``get_object`` request on the
    bucket/key of ``uriObject``, so partial loads fetch only the requested
    bytes.
    """

    uriObject: S3ResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the byte range ``[start, stop)`` of the S3 object.

        Raises
        ------
        FileNotFoundError
            Raised if the bucket or key does not exist.
        """
        if stop <= start:
            # A Range whose last byte precedes its first byte is syntactically
            # invalid (RFC 7233). Servers typically ignore such a header and
            # return the whole object. Match the slice semantics of
            # DefaultLoadHelper._readBytes instead.
            return b""
        # The HTTP Range header is inclusive of its end point, unlike a Python
        # slice, hence ``stop - 1``.
        byteRange = f"bytes={start}-{stop - 1}"
        try:
            response = self.uriObject.client.get_object(
                Bucket=self.uriObject.netloc, Key=self.uriObject.relativeToPathRoot, Range=byteRange
            )
        except (
            self.uriObject.client.exceptions.NoSuchKey,
            self.uriObject.client.exceptions.NoSuchBucket,
        ) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        body = response["Body"]
        try:
            return body.read()
        finally:
            # Release the streaming connection even if reading fails.
            body.close()
@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper that reads byte ranges of a graph save from a local file.

    The file at ``uriObject.ospath`` is opened lazily on the first read and
    kept open, so later reads only need a seek, until `close` is called.
    """

    uriObject: FileResourcePath

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the byte range ``[start, stop)`` of the file."""
        if not hasattr(self, "fileHandle"):
            self.fileHandle = open(self.uriObject.ospath, "rb")
        self.fileHandle.seek(start)
        # A negative size would make read() consume the rest of the file.
        # Clamp it so an empty or inverted range yields b"", as the slice in
        # DefaultLoadHelper._readBytes does.
        return self.fileHandle.read(max(stop - start, 0))

    def close(self) -> None:
        """Close the file handle if one was opened."""
        if hasattr(self, "fileHandle"):
            self.fileHandle.close()
            # Drop the closed handle so a later read reopens the file lazily
            # instead of failing on a closed file object.
            del self.fileHandle
@register_helper(BinaryIO)
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper for an already-open binary file handle.

    Unlike the other helpers, this one is not initialized with a
    ResourcePath but with an open file handle. Almost everything else stays
    the same; the handle is even stored as ``uriObject``, because the
    interface needed for reading is the same. Without Protocols this cannot
    be expressed nicely with typing.

    This helper supports partial loading. The handle belongs to the caller
    and is not closed by this helper.
    """

    def __init__(self, uriObject: BinaryIO, minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: BinaryIO
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # This differs from the default __init__: it forces the io object
        # back to the beginning so that, if the entire file is to be read,
        # the stream is not left in a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the byte range ``[start, stop)`` of the stream."""
        self.uriObject.seek(start)
        try:
            # A negative size would make read() consume the rest of the
            # stream. Clamp it so an empty or inverted range yields b"".
            return self.uriObject.read(max(stop - start, 0))
        finally:
            # load() performs full reads through DefaultLoadHelper._readBytes,
            # which calls read() from the current position. Rewind after
            # every partial read so a full load that follows a partial one
            # still sees the whole file.
            self.uriObject.seek(0)
@dataclass
class LoadHelper(ContextManager[DefaultLoadHelper]):
    """Select the appropriate load helper for a save and manage its lifetime.

    Entering the context constructs the helper registered for the type of
    ``uri``. That constructor raises `ValueError` if the specified file does
    not appear to be a valid `QuantumGraph` save file. Exiting the context
    calls the helper's ``close`` method.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.
    """

    uri: Union[ResourcePath, BinaryIO]
    """ResourcePath object or open binary file handle from which the
    `QuantumGraph` is to be loaded.
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __enter__(self) -> DefaultLoadHelper:
        helperClass = self._determineLoader()
        self._loaded = helperClass(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Release whatever the helper acquired (e.g. an open file handle).
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Return the helper class registered for the type of ``uri``."""
        uri = self.uri
        # Typing for file-like types is a mess: BinaryIO isn't actually a
        # base class of what open(..., 'rb') returns, and MyPy claims that
        # IOBase and BinaryIO have incompatible method signatures. IOBase *is*
        # a base class of what open(..., 'rb') returns, so it is what must be
        # used at runtime.
        if not isinstance(uri, io.IOBase):  # type: ignore
            # Everything that is not an open stream is looked up by its exact
            # type; unregistered types fall back to DefaultLoadHelper.
            return HELPER_REGISTRY[type(uri)]
        # Only one helper is registered for open streams, under the BinaryIO
        # key, so every IOBase instance maps to that key.
        return HELPER_REGISTRY[BinaryIO]

    def readHeader(self) -> Optional[str]:
        """Return the save's unpacked header without loading the graph."""
        return self._determineLoader().dumpHeader(self.uri, self.minimumVersion)