Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 39%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23from uuid import UUID
25__all__ = ("LoadHelper",)
27import functools
28import io
29import struct
30from collections import UserDict
31from dataclasses import dataclass
32from typing import IO, TYPE_CHECKING, Iterable, Optional, Type, Union
34from lsst.daf.butler import DimensionUniverse
35from lsst.resources import ResourcePath
36from lsst.resources.file import FileResourcePath
37from lsst.resources.s3 import S3ResourcePath
39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true
40 from . import QuantumGraph
class RegistryDict(UserDict):
    """Mapping of "handle type" to load helper class with a fixed fallback.

    Looking up a key that has not been registered does not raise
    `KeyError`; the lookup yields `DefaultLoadHelper` instead.
    """

    def __missing__(self, key):
        # UserDict.__getitem__ calls this hook only for absent keys; the
        # key itself is irrelevant because the fallback is always the same.
        fallback = DefaultLoadHelper
        return fallback


# Module-wide registry holding every load helper class, keyed by the type of
# object (ResourcePath subclass or file handle base class) it can read from.
HELPER_REGISTRY = RegistryDict()
def register_helper(URICLass: Union[Type[ResourcePath], Type[IO[bytes]]]):
    """Used to register classes as Load helpers

    When decorating a class the parameter is the class of "handle type", i.e.
    a ResourcePath type or open file handle that will be used to do the
    loading. This is then associated with the decorated class such that when
    the parameter type is used to load data, the appropriate helper to work
    with that data type can be returned.

    A decorator is used so that in theory someone could define another handler
    in a different module and register it for use.

    Parameters
    ----------
    URICLass : Type of `~lsst.resources.ResourcePath` or `~typing.IO` of bytes
        Type for which the decorated class should be mapped to. This is the
        key under which the decorated class is stored in ``HELPER_REGISTRY``.

    Returns
    -------
    wrapper : callable
        Class decorator that records the decorated class in
        ``HELPER_REGISTRY`` and returns the class unchanged.
    """

    def wrapper(class_):
        # Registering an already-present key silently replaces the old helper.
        HELPER_REGISTRY[URICLass] = class_
        return class_

    return wrapper
class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `typing.IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter, forces a user to
        interact manually and take this into account before they can be used in
        production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the graph save version is below the
        minimum specified version.
    RuntimeError
        Raised if the save version of the file is newer than the newest
        version this code knows how to read.
    """

    def __init__(self, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except Exception:
            # The instance is never handed to the caller when construction
            # fails, so nobody else could release what a subclass reader may
            # already have opened (e.g. a file handle).
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int) -> bytes:
        """Validate the save file preamble and read the raw header bytes.

        Stores ``uriObject`` and the version-specific ``deserializer`` on the
        instance as a side effect.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `typing.IO` of bytes
            Object from which the bytes of the save are read.
        minimumVersion : `int`
            Oldest save version that is accepted.

        Returns
        -------
        headerBytes : `bytes`
            The still-serialized header of the save file.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match or the save version is
            older than ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than ``SAVE_VERSION``.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic}, expected {MAGIC_BYTES}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise RuntimeError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        self.deserializer = deserializerClass(preambleSize, sizeBytes)

        # get the header info
        # NOTE(review): headerSize is passed as the *stop* offset of the read,
        # so it is presumably an absolute position in the file rather than a
        # length - confirm against the deserializer implementation.
        headerBytes = self._readBytes(
            preambleSize + deserializerClass.structSize, self.deserializer.headerSize
        )
        return headerBytes

    @classmethod
    def dumpHeader(
        cls, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int = 3
    ) -> Optional[str]:
        """Return the unpacked header of a save file without loading a graph.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `typing.IO` of bytes
            Object from which the bytes of the save are read.
        minimumVersion : `int`, optional
            Oldest save version that is accepted, defaults to 3.

        Returns
        -------
        header : `str` or `None`
            Whatever the deserializer's ``unpackHeader`` produces for the
            header bytes.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match or the save version is
            older than ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than this code can read.
        """
        # Bypass __init__ so the header info is not deserialized needlessly.
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # Always release reader resources, including when validation or
            # unpacking of the header fails.
            instance.close()

    def load(
        self,
        universe: DimensionUniverse,
        nodes: Optional[Iterable[Union[str, UUID]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes : `Iterable` of `str` or `uuid.UUID`, or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). Strings are converted to `uuid.UUID`.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = list(self.headerInfo.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodes = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodes - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
            _readBytes = self._readBytes
        return self.deserializer.constructGraph(nodes, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The slice ``[start:stop]`` of the save file contents.
        """
        # The whole file is read once, on first use, and cached on the
        # instance for all later calls.
        if not hasattr(self, "buffer"):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self):
        """Cleans up an instance if needed. Base class does nothing"""
        pass
@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper that implements partial loading of a graph using a s3 uri.

    Each read issues a ranged ``get_object`` request through the client
    attached to the stored ``uriObject`` instead of downloading the whole
    object.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Fetch the byte range ``[start, stop)`` of the S3 object.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        data : `bytes`
            The bytes returned by S3 for the requested range, or ``b""`` when
            the range is empty.

        Raises
        ------
        FileNotFoundError
            Raised if the client reports that the key or bucket is missing.
        """
        # An empty range cannot be expressed as a valid inclusive Range
        # header; return the same result a slice of the data would give.
        if stop <= start:
            return b""
        client = self.uriObject.client
        try:
            response = client.get_object(
                Bucket=self.uriObject.netloc,
                Key=self.uriObject.relativeToPathRoot,
                # minus 1 in the stop range, because this header is inclusive
                # rather than standard python where the end point is
                # generally exclusive
                Range=f"bytes={start}-{stop - 1}",
            )
        except (client.exceptions.NoSuchKey, client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        body = response["Body"]
        try:
            return body.read()
        finally:
            # Close the body stream even when reading it fails.
            body.close()
@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper that implements partial loading of a graph using a file
    uri.

    The file at ``uriObject.ospath`` is opened in binary mode on the first
    read and the handle is kept on the instance until `close` is called.
    """

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` of the file."""
        try:
            handle = self.fileHandle
        except AttributeError:
            # First read on this instance: open the file lazily and keep the
            # handle so later reads can reuse it.
            handle = self.fileHandle = open(self.uriObject.ospath, "rb")
        handle.seek(start)
        return handle.read(stop - start)

    def close(self):
        """Close the file handle if one was ever opened."""
        try:
            handle = self.fileHandle
        except AttributeError:
            # Nothing was read, so there is nothing to release.
            return
        handle.close()
@register_helper(io.IOBase)  # type: ignore
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper that reads from an already open binary file handle.

    This handler is special in that it does not get initialized with a
    ResourcePath, but an open file handle.

    Most everything stays the same, the variable is even stored as
    ``uriObject``, because the interface needed for reading is the same.

    This helper does support partial loading.

    Parameters
    ----------
    uriObject : `typing.IO` of bytes
        Open, seekable, binary file handle to read the save from.
    minimumVersion : `int`
        Minimum version of a save file to load.
    """

    def __init__(self, uriObject: IO[bytes], minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: IO[bytes]
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # This differs from the default __init__ to force the io object
        # back to the beginning so that in the case the entire file is to
        # read in the file is not already in a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` of the open handle."""
        self.uriObject.seek(start)
        result = self.uriObject.read(stop - start)
        return result
@dataclass
class LoadHelper:
    """Context manager that picks the load helper matching the type of
    ``uri`` and closes it again when the context exits.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: Union[ResourcePath, IO[bytes]]
    """ResourcePath object, or open binary file handle, from which the
    `QuantumGraph` is to be loaded
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __enter__(self):
        # Instantiate the helper class registered for this kind of uri and
        # remember it so that __exit__ can close it.
        helperClass = self._determineLoader()
        self._loaded = helperClass(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(self, type, value, traceback):
        # Returns None, so an exception raised inside the context propagates.
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Return the helper class registered for the type of ``uri``."""
        # Only one handler is registered for anything that is an instance of
        # IOBase, so if any type is a subtype of that, set the key explicitly
        # so the correct loader is found, otherwise index by the type
        key = io.IOBase if isinstance(self.uri, io.IOBase) else type(self.uri)
        return HELPER_REGISTRY[key]

    def readHeader(self) -> Optional[str]:
        """Return the header of the save file without loading the graph."""
        return self._determineLoader().dumpHeader(self.uri, self.minimumVersion)
375 return type_.dumpHeader(self.uri, self.minimumVersion)