Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 39%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
22from uuid import UUID
24__all__ = ("LoadHelper", )
26from lsst.resources import ResourcePath
27from lsst.resources.s3 import S3ResourcePath
28from lsst.resources.file import FileResourcePath
29from lsst.daf.butler import DimensionUniverse
32from dataclasses import dataclass
33import functools
34import io
35import struct
37from collections import UserDict
38from typing import Optional, Iterable, TYPE_CHECKING, Type, Union, IO
40if TYPE_CHECKING: 40 ↛ 41line 40 didn't jump to line 41, because the condition on line 40 was never true
41 from . import QuantumGraph
class RegistryDict(UserDict):
    """Mapping from a "handle type" to the load helper class registered for it.

    Looking up a key that was never registered does not raise `KeyError`;
    the lookup falls back to `DefaultLoadHelper` instead. The fallback is
    only returned, it is not stored in the mapping.
    """

    def __missing__(self, key):
        # Any handle type without a dedicated helper is served by the default
        # helper.
        return DefaultLoadHelper


# Module level registry that holds every registered load helper class, keyed
# by the type of object the helper knows how to read from.
HELPER_REGISTRY = RegistryDict()
def register_helper(URICLass: Union[Type[ResourcePath], Type[IO[bytes]]]):
    """Used to register classes as Load helpers

    When decorating a class the parameter is the class of "handle type", i.e.
    a ResourcePath type or open file handle that will be used to do the
    loading. This is then associated with the decorated class such that when
    the parameter type is used to load data, the appropriate helper to work
    with that data type can be returned.

    A decorator is used so that in theory someone could define another handler
    in a different module and register it for use.

    Parameters
    ----------
    URICLass : Type of `~lsst.resources.ResourcePath` or `~typing.IO` of bytes
        type for which the decorated class should be mapped to

    Returns
    -------
    wrapper : callable
        Class decorator that stores the decorated class in ``HELPER_REGISTRY``
        under the key ``URICLass`` and returns the class unchanged.
    """
    def wrapper(class_):
        # Registering the same handle type twice replaces the earlier helper.
        HELPER_REGISTRY[URICLass] = class_
        return class_
    return wrapper
class DefaultLoadHelper:
    """Default load helper for `QuantumGraph` save files

    This class, and its subclasses, are used to unpack a quantum graph save
    file. This file is a binary representation of the graph in a format that
    allows individual nodes to be loaded without needing to load the entire
    file.

    This default implementation has the interface to load select nodes
    from disk, but actually always loads the entire save file and simply
    returns what nodes (or all) are requested. This is intended to serve for
    all cases where there is a read method on the input parameter, but it is
    unknown how to read select bytes of the stream. It is the responsibility of
    sub classes to implement the method responsible for loading individual
    bytes from the stream.

    Parameters
    ----------
    uriObject : `~lsst.resources.ResourcePath` or `typing.IO` of bytes
        This is the object that will be used to retrieve the raw bytes of the
        save.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format. This upgrade may not happen
        deterministically each time an older graph format is loaded. Because
        of this behavior, the minimumVersion parameter, forces a user to
        interact manually and take this into account before they can be used in
        production.

    Raises
    ------
    ValueError
        Raised if the specified file contains the wrong file signature and is
        not a `QuantumGraph` save, or if the graph save version is below the
        minimum specified version.
    RuntimeError
        Raised if the save version of the file is newer than the newest
        version this code knows how to read.
    """
    def __init__(self, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int):
        try:
            headerBytes = self.__setup_impl(uriObject, minimumVersion)
            self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)
        except Exception:
            # The instance is never handed to the caller when construction
            # fails, so nobody else can release what a specialized reader may
            # already have opened while reading the header.
            self.close()
            raise

    def __setup_impl(self, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int) -> bytes:
        """Validate the preamble of the save file and read the raw header.

        Sets ``self.uriObject`` and ``self.deserializer`` as a side effect.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `typing.IO` of bytes
            Object the bytes of the save are read from.
        minimumVersion : `int`
            Oldest save version that is accepted.

        Returns
        -------
        headerBytes : `bytes`
            The raw, still serialized, header of the save file.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match, or the save version is
            below ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than this code supports.
        """
        self.uriObject = uriObject
        # need to import here to avoid cyclic imports
        from .graph import STRUCT_FMT_BASE, MAGIC_BYTES, SAVE_VERSION
        from ._versionDeserializers import DESERIALIZER_MAP

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        if magic != MAGIC_BYTES:
            raise ValueError("This file does not appear to be a quantum graph save got magic bytes "
                             f"{magic}, expected {MAGIC_BYTES}")

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        save_version, = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < minimumVersion:
            raise ValueError(f"The loaded QuantumGraph is version {save_version}, and the minimum "
                             f"version specified is {minimumVersion}. Please re-run this method "
                             "with a lower minimum version, then re-save the graph to automatically upgrade "
                             "to the newest version. Older versions may not work correctly with newer code")

        if save_version > SAVE_VERSION:
            raise RuntimeError(f"The version of this save file is {save_version}, but this version of "
                               f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}")

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizeBytes = self._readBytes(preambleSize, preambleSize+deserializerClass.structSize)
        self.deserializer = deserializerClass(preambleSize, sizeBytes)

        # get the header info
        headerBytes = self._readBytes(preambleSize+deserializerClass.structSize,
                                      self.deserializer.headerSize)
        return headerBytes

    @classmethod
    def dumpHeader(cls, uriObject: Union[ResourcePath, IO[bytes]], minimumVersion: int = 3
                   ) -> Optional[str]:
        """Return the header of a save file without loading the graph.

        Parameters
        ----------
        uriObject : `~lsst.resources.ResourcePath` or `typing.IO` of bytes
            Object the bytes of the save are read from.
        minimumVersion : `int`, optional
            Oldest save version that is accepted, defaults to 3.

        Returns
        -------
        header : `str` or `None`
            Whatever the version specific deserializer unpacks from the raw
            header bytes.

        Raises
        ------
        ValueError
            Raised if the magic bytes do not match, or the save version is
            below ``minimumVersion``.
        RuntimeError
            Raised if the save version is newer than this code supports.
        """
        # __init__ is bypassed on purpose so that the header info is not
        # deserialized, only the raw header is needed here
        instance = cls.__new__(cls)
        try:
            headerBytes = instance.__setup_impl(uriObject, minimumVersion)
            return instance.deserializer.unpackHeader(headerBytes)
        finally:
            # always release resources, also when validation or unpacking
            # raised
            instance.close()

    def load(self, universe: DimensionUniverse, nodes: Optional[Iterable[Union[str, UUID]]] = None,
             graphID: Optional[str] = None) -> QuantumGraph:
        """Loads in the specified nodes from the graph

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes : `Iterable` of `str` or `~uuid.UUID`, or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). Identifiers given as `str` are converted to
            `~uuid.UUID` before they are looked up.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        """
        # verify this is the expected graph
        if graphID is not None and self.headerInfo._buildId != graphID:
            raise ValueError('graphID does not match that of the graph being loaded')
        # Read in specified nodes, or all the nodes
        if nodes is None:
            nodes = list(self.headerInfo.map.keys())
            # if all nodes are to be read, force the reader from the base class
            # that will read all the bytes in one go
            _readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodes = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodes - self.headerInfo.map.keys()
            if remainder:
                raise ValueError(f"Nodes {remainder} were requested, but could not be found in the input "
                                 "graph")
            _readBytes = self._readBytes
        return self.deserializer.constructGraph(nodes, _readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Loads the specified byte range from the ResourcePath object

        In the base class, this actually will read all the bytes into a buffer
        from the specified ResourcePath object. Then for each method call will
        return the requested byte range. This is the most flexible
        implementation, as no special read is required. This will not give a
        speed up with any sub graph reads though.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        result : `bytes`
            The bytes in the half open range ``[start, stop)``.
        """
        # the complete content is read once and cached on the instance
        if not hasattr(self, 'buffer'):
            self.buffer = self.uriObject.read()
        return self.buffer[start:stop]

    def close(self):
        """Cleans up an instance if needed. Base class does nothing
        """
        pass
@register_helper(S3ResourcePath)
class S3LoadHelper(DefaultLoadHelper):
    """Load helper that implements partial loading of a graph using a s3 uri
    """
    def _readBytes(self, start: int, stop: int) -> bytes:
        """Fetch the byte range ``[start, stop)`` with a ranged S3 request.

        Parameters
        ----------
        start : `int`
            Offset of the first byte to return.
        stop : `int`
            Offset one past the last byte to return.

        Returns
        -------
        body : `bytes`
            The requested bytes, empty if the range is empty.

        Raises
        ------
        FileNotFoundError
            Raised if the key or the bucket of the uri does not exist.
        """
        # An empty range can not be expressed in a Range header (the end
        # would come before the start), so answer it locally the same way the
        # buffer and file based readers do.
        if stop <= start:
            return b""
        client = self.uriObject.client
        try:
            # minus 1 in the stop range, because this header is inclusive
            # rather than standard python where the end point is generally
            # exclusive
            response = client.get_object(Bucket=self.uriObject.netloc,
                                         Key=self.uriObject.relativeToPathRoot,
                                         Range=f"bytes={start}-{stop-1}")
        except (client.exceptions.NoSuchKey,
                client.exceptions.NoSuchBucket) as err:
            raise FileNotFoundError(f"No such resource: {self.uriObject}") from err
        stream = response["Body"]
        try:
            return stream.read()
        finally:
            # release the stream even if the read fails part way through
            stream.close()
@register_helper(FileResourcePath)
class FileLoadHelper(DefaultLoadHelper):
    """Load helper that implements partial loading of a graph using a file uri
    """
    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` read from the local file.

        The file at ``self.uriObject.ospath`` is opened in binary mode on the
        first call and the handle is kept on the instance for later calls.
        """
        try:
            handle = self.fileHandle
        except AttributeError:
            # first read, open the file lazily and remember the handle
            handle = self.fileHandle = open(self.uriObject.ospath, 'rb')
        handle.seek(start)
        return handle.read(stop - start)

    def close(self):
        """Close the file handle, if one was ever opened.
        """
        try:
            handle = self.fileHandle
        except AttributeError:
            # nothing was read through this helper, nothing to release
            return
        handle.close()
@register_helper(io.IOBase)  # type: ignore
class OpenFileHandleHelper(DefaultLoadHelper):
    """Load helper for an already open, seekable, binary file handle.

    This handler is special in that it does not get initialized with a
    ResourcePath, but an open file handle.

    Most everything stays the same, the variable is even stored as uriObject,
    because the interface needed for reading is the same. Unfortunately
    because we do not have Protocols yet, this can not be nicely expressed
    with typing.

    This helper does support partial loading. The handle is owned by the
    caller, ``close`` is not overridden here and so leaves it open.

    Parameters
    ----------
    uriObject : `typing.IO` of bytes
        Open file handle the save is read from.
    minimumVersion : `int`
        Minimum version of a save file to load.
    """

    def __init__(self, uriObject: IO[bytes], minimumVersion: int):
        # Explicitly annotate type and not infer from super
        self.uriObject: IO[bytes]
        super().__init__(uriObject, minimumVersion=minimumVersion)
        # This differs from the default __init__ to force the io object
        # back to the beginning so that in the case the entire file is to
        # read in the file is not already in a partially read state.
        self.uriObject.seek(0)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Return the bytes in ``[start, stop)`` by seeking the open handle.
        """
        self.uriObject.seek(start)
        return self.uriObject.read(stop - start)
@dataclass
class LoadHelper:
    """Context manager that picks the load helper matching the type of
    ``uri`` and takes care of closing it again.

    Entering the context builds the selected helper and returns it, leaving
    the context calls ``close`` on that helper.

    Note
    ----
    This class may go away or be modified in the future if some of the
    features of this module can be propagated to
    `~lsst.resources.ResourcePath`.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """
    uri: Union[ResourcePath, IO[bytes]]
    """ResourcePath object, or open binary file handle, from which the
    `QuantumGraph` is to be loaded
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __enter__(self):
        loaderClass = self._determineLoader()
        self._loaded = loaderClass(self.uri, self.minimumVersion)
        return self._loaded

    def __exit__(self, type, value, traceback):
        # exceptions raised inside the context are not suppressed
        self._loaded.close()

    def _determineLoader(self) -> Type[DefaultLoadHelper]:
        """Look up the helper class registered for the type of ``uri``.
        """
        # Only one handler is registered for anything that is an instance of
        # IOBase, so if any type is a subtype of that, set the key explicitly
        # so the correct loader is found, otherwise index by the type
        key = io.IOBase if isinstance(self.uri, io.IOBase) else type(self.uri)
        return HELPER_REGISTRY[key]

    def readHeader(self) -> Optional[str]:
        """Return the header of the save file without entering the context.
        """
        return self._determineLoader().dumpHeader(self.uri, self.minimumVersion)