Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 26%
82 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-19 04:01 -0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-19 04:01 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("LoadHelper",)
25import struct
26from contextlib import ExitStack
27from dataclasses import dataclass
28from io import BufferedRandom, BytesIO
29from types import TracebackType
30from typing import TYPE_CHECKING, BinaryIO, ContextManager, Iterable, Optional, Set, Type, Union
31from uuid import UUID
33from lsst.daf.butler import DimensionUniverse
34from lsst.resources import ResourceHandleProtocol, ResourcePath
36if TYPE_CHECKING: 36 ↛ 37line 36 didn't jump to line 37, because the condition on line 36 was never true
37 from ._versionDeserializers import DeserializerBase
38 from .graph import QuantumGraph
@dataclass
class LoadHelper(ContextManager["LoadHelper"]):
    """This is a helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.

    The helper must be used as a context manager. Entering the context opens
    ``uri`` (unless it is already an open binary stream), validates the file
    preamble and selects the version-specific deserializer. Leaving the
    context closes anything the helper opened itself; a caller-supplied
    stream is never closed by the helper.
    """

    uri: Union[ResourcePath, BinaryIO]
    """ResourcePath object, or an already-open binary stream, from which the
    `QuantumGraph` is to be loaded.
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __post_init__(self) -> None:
        # Handle to read from; only non-None while inside a context block.
        self._resourceHandle: Optional[ResourceHandleProtocol] = None
        # Owns handles this helper opened itself (never a caller's stream).
        self._exitStack = ExitStack()

    def _initialize(self) -> None:
        """Validate the file preamble and construct the matching deserializer.

        Sets ``self.deserializer`` and ``self.headerBytesRange``.

        Raises
        ------
        ValueError
            Raised if the file is not a `QuantumGraph` save, if its save
            version is outside the accepted range, or if no deserializer is
            registered for its save version.
        """
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, STRUCT_FMT_BASE

        # The preamble is the magic identifier bytes followed by the save
        # version, packed with STRUCT_FMT_BASE.
        magicSize = len(MAGIC_BYTES)
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        save_version = self._validateSave(magic, versionBytes)

        # select the appropriate deserializer for this save version
        try:
            deserializerClass = DESERIALIZER_MAP[save_version]
        except KeyError as err:
            # Keep the documented contract: an unreadable save is a ValueError.
            raise ValueError(
                f"No deserializer is available for QuantumGraph save version {save_version}"
            ) from err

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. These bytes describe the byte boundaries of the
        # header info that follows.
        sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
        # DeserializerBase subclasses are required to have the same
        # constructor signature as the base class itself; the type system
        # cannot express that requirement.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)
        # get the header byte range for later reading
        self.headerBytesRange = (preambleSize + deserializerClass.structSize, self.deserializer.headerSize)

    def _validateSave(self, magic: bytes, versionBytes: bytes) -> int:
        """Implement validation on input file, prior to attempting to load it.

        Parameters
        ----------
        magic : `bytes`
            The first few bytes of the file, used to verify it is a
            QuantumGraph save file.
        versionBytes : `bytes`
            The next few bytes from the beginning of the file, used to parse
            which version of the QuantumGraph file the save corresponds to.

        Returns
        -------
        save_version : `int`
            The save version parsed from the supplied bytes.

        Raises
        ------
        ValueError
            Raised if the specified file contains the wrong file signature and
            is not a `QuantumGraph` save, if it is too short to contain a save
            version, or if the graph save version is below the minimum
            specified version or above the newest version this code can read.
        """
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # A file truncated right after the magic bytes would otherwise surface
        # as a raw struct.error from the unpack below.
        expectedSize = struct.calcsize(STRUCT_FMT_BASE)
        if len(versionBytes) != expectedSize:
            raise ValueError(
                "This file is too short to be a quantum graph save: expected "
                f"{expectedSize} version bytes after the magic bytes, got {len(versionBytes)}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < self.minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {self.minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise ValueError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )
        return save_version

    def load(
        self,
        universe: Optional[DimensionUniverse] = None,
        nodes: Optional[Iterable[Union[UUID, str]]] = None,
        graphID: Optional[str] = None,
    ) -> QuantumGraph:
        """Load in the specified nodes from the graph.

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe : `~lsst.daf.butler.DimensionUniverse` or `None`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
            The universe saved with the graph is used, but if one is passed
            it will be used to validate the compatibility with the loaded
            graph universe.
        nodes : `Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). String entries are parsed as UUIDs.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object.

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        RuntimeError
            Raised if Supplied DimensionUniverse is not compatible with the
            DimensionUniverse saved in the graph.
            Raised if the method was not called from within a context block.
        """
        if self._resourceHandle is None:
            raise RuntimeError("Load can only be used within a context manager")

        headerInfo = self.deserializer.readHeaderInfo(self._readBytes(*self.headerBytesRange))
        # verify this is the expected graph
        if graphID is not None and headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: Set[UUID]
        if nodes is None:
            nodeSet = set(headerInfo.map.keys())
        else:
            # A set guarantees each requested node is loaded only once.
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
        if universe is None:
            universe = headerInfo.universe
        return self.deserializer.constructGraph(nodeSet, self._readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Load the specified byte range from the open resource handle.

        Parameters
        ----------
        start : `int`
            The beginning byte location to read.
        stop : `int`
            The end byte location to read (exclusive).

        Returns
        -------
        result : `bytes`
            The byte range specified from the `ResourceHandle`. This may be
            shorter than requested if the underlying stream ends first.

        Raises
        ------
        RuntimeError
            Raise if the method was not called from within a context block.
        """
        if self._resourceHandle is None:
            raise RuntimeError("_readBytes must be called from within a context block")
        self._resourceHandle.seek(start)
        return self._resourceHandle.read(stop - start)

    def __enter__(self) -> "LoadHelper":
        # Any io stream (e.g. the BufferedReader returned by open(path, "rb"))
        # is used directly; typing.BinaryIO alone does not match real file
        # objects, which previously sent them down the `.open` path.
        from io import IOBase

        if isinstance(self.uri, (BinaryIO, BytesIO, BufferedRandom, IOBase)):
            self._resourceHandle = self.uri
        else:
            self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb"))
        try:
            self._initialize()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release any
            # handle opened above here instead of leaking it.
            self._exitStack.close()
            self._resourceHandle = None
            raise
        return self

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        assert self._resourceHandle is not None
        # Closes only handles this helper opened; caller streams stay open.
        self._exitStack.close()
        self._resourceHandle = None

    def readHeader(self) -> Optional[str]:
        """Return the decoded header of the save file.

        The source is entered for the duration of this call and released on
        return. Calling this while already inside a context block on the same
        helper closes that block's handle as well.

        Returns
        -------
        header : `str` or `None`
            The value returned by the version-specific deserializer's
            ``unpackHeader`` for the header byte range.
        """
        with self as handle:
            return handle.deserializer.unpackHeader(handle._readBytes(*handle.headerBytesRange))