Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 38%
132 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-12 02:06 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-12 02:06 -0800
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("LoadHelper",)
25import functools
26import io
27import struct
28from dataclasses import dataclass
29from types import TracebackType
30from typing import (
31 TYPE_CHECKING,
32 Any,
33 BinaryIO,
34 Callable,
35 ContextManager,
36 Iterable,
37 Optional,
38 Set,
39 Type,
40 TypeVar,
41 Union,
42)
43from uuid import UUID
45from lsst.daf.butler import DimensionUniverse
46from lsst.resources import ResourcePath
47from lsst.resources.file import FileResourcePath
48from lsst.resources.s3 import S3ResourcePath
50if TYPE_CHECKING: 50 ↛ 51line 50 didn't jump to line 51, because the condition on line 50 was never true
51 from ._versionDeserializers import DeserializerBase
52 from .graph import QuantumGraph
_T = TypeVar("_T")


class RegistryDict(dict):
    """A ``dict`` that answers lookups of unknown keys with `DefaultLoadHelper`.

    The fallback value is returned from ``d[key]`` but is not inserted into
    the mapping, so membership tests and ``get`` are unaffected.
    """

    def __missing__(self, key: Any) -> Type[DefaultLoadHelper]:
        # ``dict.__getitem__`` invokes this hook only when ``key`` is absent.
        return DefaultLoadHelper


# Mapping from a "handle type" to the load helper class that reads it;
# entries are added by the ``register_helper`` decorator.
HELPER_REGISTRY = RegistryDict()
def register_helper(URIClass: Union[Type[ResourcePath], Type[BinaryIO]]) -> Callable[[_T], _T]:
    """Build a class decorator that registers a load helper.

    The decorated class is recorded in ``HELPER_REGISTRY`` under ``URIClass``.
    ``URIClass`` is the "handle type": a `~lsst.resources.ResourcePath` type,
    or ``BinaryIO`` for open file handles. Later lookups by that handle type
    then resolve to the decorated helper. Using a decorator means a helper
    defined in another module can register itself in the same way.

    Parameters
    ----------
    URIClass : Type of `~lsst.resources.ResourcePath` or `~IO` of bytes
        Handle type that the decorated class should be mapped to.

    Returns
    -------
    decorator : `Callable`
        Decorator that stores its argument in the registry and returns that
        argument unchanged.
    """

    def _record(helperClass: _T) -> _T:
        HELPER_REGISTRY[URIClass] = helperClass
        return helperClass

    return _record
class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    subclasses to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter forces a user to
        interact manually and take this into account before they can be used
        in production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, if it is truncated before the save version,
        or if the graph save version is below the minimum specified version.
    RuntimeError
        Raised if the graph save version is newer than this code can read.

    Notes
    -----
    If construction fails, `close` is called before the exception propagates
    so that subclasses holding open resources release them.
    """

    def __init__(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except BaseException:
            # A half-constructed helper is never handed back to the caller, so
            # nothing else could close resources (e.g. a file handle opened by
            # a subclass's ``_readBytes``) acquired while reading the header.
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int) -> bytes:
        """Validate the save-file preamble and return the raw header bytes.

        As side effects this stores ``uriObject`` on ``self.uriObject`` and the
        version-specific deserializer on ``self.deserializer``.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # The preamble is the magic identifier bytes followed by the save
        # version, packed with STRUCT_FMT_BASE.
        magicSize = len(MAGIC_BYTES)
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # Without this check a truncated file would surface as a bare
        # struct.error from the unpack below rather than the documented
        # ValueError.
        if len(versionBytes) != fmtSize:
            raise ValueError(
                f"This quantum graph save is truncated: expected {fmtSize} save version bytes, "
                f"got {len(versionBytes)}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system, so we just tell MyPy to ignore it.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)  # type: ignore

        # get the header info; the deserializer's headerSize is used as the
        # (exclusive) stop offset of the read
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(cls, uriObject: Union[ResourcePath, BinaryIO], minimumVersion: int = 3) -> Optional[str]:
        """Read and unpack the header of a save file without loading nodes.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `IO` of bytes
            Object from which the raw bytes of the save are read.
        minimumVersion : `int`, optional
            Minimum save version to accept, as for the constructor.

        Returns
        -------
        header : `str` or `None`
            The value returned by the deserializer's ``unpackHeader``.
        """
        # Bypass __init__ so only the header is read, not parsed into
        # headerInfo.
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # Release resources even if reading or unpacking the header fails.
            instance.close()

    def load(
        self,
        universe: Optional[DimensionUniverse] = None,
        nodes: Optional[Iterable[Union[UUID, str]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Load in the specified nodes from the graph.

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe: `~lsst.daf.butler.DimensionUniverse` or None
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
            The universe saved with the graph is used, but if one is passed
            it will be used to validate the compatibility with the loaded
            graph universe.
        nodes : `Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default)
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        RuntimeError
            Raised if the supplied DimensionUniverse is not compatible with
            the DimensionUniverse saved in the graph.
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: Set[UUID]
        if nodes is None:
            nodeSet = set(self.headerInfo.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes: Callable[[int, int], bytes] = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        if universe is None:
            universe = self.headerInfo.universe
        return self.deserializer.constructGraph(nodeSet, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Load the specified byte range from the ResourcePath object.

        In the base class, this reads the whole stream once via
        ``uriObject.read()``, caches it in ``self.buffer``, and returns the
        slice ``[start:stop]`` on each call. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.
        """
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self) -> None:
        """Clean up an instance if needed. The base class does nothing."""
        pass
@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper that reads byte ranges of an S3 object on demand.

    This subclass implements partial loading of a graph using a s3 uri: each
    ``_readBytes`` call issues a ranged ``get_object`` request instead of
    downloading the whole object.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return bytes ``[start, stop)`` of the S3 object.

        Raises
        ------
        FileNotFoundError
            Raised if the bucket or key does not exist.
        """
        if stop <= start:
            # "bytes=N-M" with M < N is not a valid byte range and a server may
            # ignore it and return the entire object; match the base class's
            # slice semantics and return nothing instead.
            return b""
        # minus 1 in the stop range, because this header is inclusive rather
        # than standard python where the end point is generally exclusive
        byteRange = f"bytes={start}-{stop - 1}"
        try:
            response = self.uriObject.client.get_object(
                Bucket=self.uriObject.netloc, Key=self.uriObject.relativeToPathRoot, Range=byteRange
            )
        except (
            self.uriObject.client.exceptions.NoSuchKey,
            self.uriObject.client.exceptions.NoSuchBucket,
        ) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        body = response["Body"]
        try:
            return body.read()
        finally:
            # Close the streaming body even if reading it fails.
            body.close()

    uriObject: S3ResourcePath
@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper that reads byte ranges of a local file on demand.

    This subclass implements partial loading of a graph using a file uri. The
    file is opened lazily on the first read and kept open until `close`.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return bytes ``[start, stop)`` of the file."""
        if not hasattr(self, "fileHandle"):
            self.fileHandle = open(self.uriObject.ospath, "rb")
        self.fileHandle.seek(start)
        # read() with a negative size reads to EOF; clamp so an empty or
        # inverted range yields b"" like the base-class slice does.
        return self.fileHandle.read(max(stop - start, 0))

    def close(self) -> None:
        """Close the file handle if one was opened."""
        if hasattr(self, "fileHandle"):
            self.fileHandle.close()

    uriObject: FileResourcePath
@register_helper(BinaryIO)
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper for an already-open binary file handle.

    This handler is special in that it does not get initialized with a
    ResourcePath, but an open file handle. Most everything stays the same;
    the handle is even stored as ``uriObject``, because the interface needed
    for reading is the same. Unfortunately, because we do not have Protocols
    yet, this can not be nicely expressed with typing.

    This helper does support partial loading.
    """

    def __init__(self, uriObject: BinaryIO, minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: BinaryIO
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # This differs from the default __init__ to force the io object
        # back to the beginning so that, in the case the entire file is to
        # be read, the file is not already in a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return bytes ``[start, stop)`` of the handle, via seek and read."""
        self.uriObject.seek(start)
        # read() with a negative size reads to EOF; clamp so an empty or
        # inverted range yields b"" like the base-class slice does.
        return self.uriObject.read(max(stop - start, 0))
@dataclass
class LoadHelper(ContextManager[DefaultLoadHelper]):
    """Pick the load helper that matches ``uri`` and manage its lifetime.

    Entering the context constructs the helper selected from
    ``HELPER_REGISTRY``. Leaving the context calls that helper's ``close``.

    Notes
    -----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    Entering the context will raise a `ValueError` if the specified file does
    not appear to be a valid `QuantumGraph` save file.
    """

    # ResourcePath object (or open binary file handle) from which the
    # `QuantumGraph` is to be loaded.
    uri: Union[ResourcePath, BinaryIO]
    # Minimum version of a save file to load. Set to -1 to load all versions.
    # Older versions may need to be loaded, and re-saved to upgrade them to
    # the latest format before they can be used in production.
    minimumVersion: int

    def __enter__(self) -> DefaultLoadHelper:
        helperClass = self._determineLoader()
        self._loaded = helperClass(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Return the registered helper class for the type of ``self.uri``."""
        # Only one handler is registered for anything that is an instance of
        # IOBase, so any such object is looked up under BinaryIO; everything
        # else is looked up by its exact type.
        #
        # Typing for file-like types is a mess; BinaryIO isn't actually a base
        # class of what open(..., 'rb') returns, and MyPy claims that IOBase
        # and BinaryIO have incompatible method signatures. IOBase *is* a base
        # class of what open(..., 'rb') returns, so it's what we have to use
        # at runtime.
        isFileLike = isinstance(self.uri, io.IOBase)  # type: ignore
        key: Union[Type[ResourcePath], Type[BinaryIO]] = BinaryIO if isFileLike else type(self.uri)
        return HELPER_REGISTRY[key]

    def readHeader(self) -> Optional[str]:
        """Return the unpacked header of ``uri`` using the matching helper."""
        return self._determineLoader().dumpHeader(self.uri, self.minimumVersion)