Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 26%
82 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-23 08:14 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-23 08:14 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("LoadHelper",)
25import struct
26from collections.abc import Iterable
27from contextlib import AbstractContextManager, ExitStack
28from dataclasses import dataclass
29from io import BufferedRandom, BytesIO
30from types import TracebackType
31from typing import TYPE_CHECKING, BinaryIO
32from uuid import UUID
34from lsst.daf.butler import DimensionUniverse, PersistenceContextVars
35from lsst.resources import ResourceHandleProtocol, ResourcePath
37if TYPE_CHECKING:
38 from ._versionDeserializers import DeserializerBase
39 from .graph import QuantumGraph
@dataclass
class LoadHelper(AbstractContextManager["LoadHelper"]):
    """Helper class to assist with selecting the appropriate loader
    and managing any contexts that may be needed.

    The helper must be used as a context manager: entering the context
    opens (or adopts) the underlying byte stream and parses the file
    preamble; leaving the context releases anything the helper opened.

    This helper will raise a `ValueError` if the specified file does not appear
    to be a valid `QuantumGraph` save file.
    """

    uri: ResourcePath | BinaryIO
    """ResourcePath object, or an already opened binary stream, from which
    the `QuantumGraph` is to be loaded.
    """
    minimumVersion: int
    """
    Minimum version of a save file to load. Set to -1 to load all
    versions. Older versions may need to be loaded, and re-saved
    to upgrade them to the latest format before they can be used in
    production.
    """

    def __post_init__(self) -> None:
        # No handle is held until the context is entered; the exit stack
        # owns whatever this helper opens itself.
        self._resourceHandle: ResourceHandleProtocol | None = None
        self._exitStack = ExitStack()

    def _initialize(self) -> None:
        """Parse the file preamble and set up the matching deserializer.

        Sets ``self.deserializer`` and ``self.headerBytesRange``.

        Raises
        ------
        ValueError
            Raised by `_validateSave` if the file is not a `QuantumGraph`
            save or its version is outside the accepted range.
        """
        # need to import here to avoid cyclic imports
        from ._versionDeserializers import DESERIALIZER_MAP
        from .graph import MAGIC_BYTES, STRUCT_FMT_BASE

        # Read the first few bytes which correspond to the magic identifier
        # bytes, and save version
        magicSize = len(MAGIC_BYTES)
        # read in just the fmt base to determine the save version
        fmtSize = struct.calcsize(STRUCT_FMT_BASE)
        preambleSize = magicSize + fmtSize
        headerBytes = self._readBytes(0, preambleSize)
        magic = headerBytes[:magicSize]
        versionBytes = headerBytes[magicSize:]

        save_version = self._validateSave(magic, versionBytes)

        # select the appropriate deserializer for this save version
        deserializerClass = DESERIALIZER_MAP[save_version]

        # read in the bytes corresponding to the mappings and initialize the
        # deserializer. This will be the bytes that describe the following
        # byte boundaries of the header info
        sizesEnd = preambleSize + deserializerClass.structSize
        sizeBytes = self._readBytes(preambleSize, sizesEnd)
        # DeserializerBase subclasses are required to have the same constructor
        # signature as the base class itself, but there is no way to express
        # this in the type system.
        self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)
        # get the header byte range for later reading
        self.headerBytesRange = (sizesEnd, self.deserializer.headerSize)

    def _validateSave(self, magic: bytes, versionBytes: bytes) -> int:
        """Implement validation on input file, prior to attempting to load it.

        Parameters
        ----------
        magic : `bytes`
            The first few bytes of the file, used to verify it is a
            `QuantumGraph` save file.
        versionBytes : `bytes`
            The next few bytes from the beginning of the file, used to parse
            which version of the `QuantumGraph` file the save corresponds to.

        Returns
        -------
        save_version : `int`
            The save version parsed from the supplied bytes.

        Raises
        ------
        ValueError
            Raised if the specified file contains the wrong file signature and
            is not a `QuantumGraph` save, if the graph save version is
            below the minimum specified version, or if it is newer than the
            newest version this code can read.
        """
        from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE

        if magic != MAGIC_BYTES:
            raise ValueError(
                "This file does not appear to be a quantum graph save got magic bytes "
                f"{magic!r}, expected {MAGIC_BYTES!r}"
            )

        # unpack the save version bytes and verify it is a version that this
        # code can understand
        (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
        # loads can sometimes trigger upgrades in format to a latest version,
        # in which case accessory code might not match the upgraded graph.
        # I.E. switching from old node number to UUID. This clause necessitates
        # that users specifically interact with older graph versions and verify
        # everything happens appropriately.
        if save_version < self.minimumVersion:
            raise ValueError(
                f"The loaded QuantumGraph is version {save_version}, and the minimum "
                f"version specified is {self.minimumVersion}. Please re-run this method "
                "with a lower minimum version, then re-save the graph to automatically upgrade "
                "to the newest version. Older versions may not work correctly with newer code"
            )

        if save_version > SAVE_VERSION:
            raise ValueError(
                f"The version of this save file is {save_version}, but this version of "
                f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
            )
        return save_version

    def load(
        self,
        universe: DimensionUniverse | None = None,
        nodes: Iterable[UUID | str] | None = None,
        graphID: str | None = None,
    ) -> QuantumGraph:
        """Load in the specified nodes from the graph.

        Load in the `QuantumGraph` containing only the nodes specified in the
        ``nodes`` parameter from the graph specified at object creation. If
        ``nodes`` is None (the default) the whole graph is loaded.

        Parameters
        ----------
        universe : `~lsst.daf.butler.DimensionUniverse` or None
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
            The universe saved with the graph is used, but if one is passed
            it will be used to validate the compatibility with the loaded
            graph universe.
        nodes : `~collections.abc.Iterable` of `UUID` or `str`; or `None`
            The nodes to load from the graph, loads all if value is None
            (the default). Strings are converted to `UUID`.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            The loaded `QuantumGraph` object.

        Raises
        ------
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded.
        RuntimeError
            Raised if supplied `~lsst.daf.butler.DimensionUniverse` is not
            compatible with the `~lsst.daf.butler.DimensionUniverse` saved in
            the graph. Raised if the method was not called from within a
            context block.
        """
        if self._resourceHandle is None:
            raise RuntimeError("Load can only be used within a context manager")

        headerInfo = self.deserializer.readHeaderInfo(self._readBytes(*self.headerBytesRange))
        # verify this is the expected graph
        if graphID is not None and headerInfo._buildId != graphID:
            raise ValueError("graphID does not match that of the graph being loaded")
        # Read in specified nodes, or all the nodes
        nodeSet: set[UUID]
        if nodes is None:
            nodeSet = set(headerInfo.map.keys())
        else:
            # only some bytes are being read using the reader specialized for
            # this class
            # create a set to ensure nodes are only loaded once
            nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
            # verify that all nodes requested are in the graph
            remainder = nodeSet - headerInfo.map.keys()
            if remainder:
                raise ValueError(
                    f"Nodes {remainder} were requested, but could not be found in the input graph"
                )
        if universe is None:
            universe = headerInfo.universe
        # use the daf butler context vars to aid in ensuring deduplication in
        # object instantiation.
        runner = PersistenceContextVars()
        return runner.run(self.deserializer.constructGraph, nodeSet, self._readBytes, universe)

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Load the specified byte range from the open resource handle.

        Parameters
        ----------
        start : `int`
            The beginning byte location to read.
        stop : `int`
            The end byte location to read (exclusive); ``stop - start``
            bytes are requested from the handle.

        Returns
        -------
        result : `bytes`
            The byte range specified from the
            `~lsst.resources.ResourceHandleProtocol`.

        Raises
        ------
        RuntimeError
            Raised if the method was not called from within a context block.
        """
        if self._resourceHandle is None:
            raise RuntimeError("_readBytes must be called from within a context block")
        self._resourceHandle.seek(start)
        return self._resourceHandle.read(stop - start)

    def __enter__(self) -> "LoadHelper":
        """Acquire the byte stream and parse the file preamble.

        An already opened binary stream is used as is and is not closed by
        this helper; anything else is opened for binary reading and
        registered on the exit stack so it is closed on exit.

        Raises
        ------
        ValueError
            Raised if the file is not a valid `QuantumGraph` save file.
        """
        if isinstance(self.uri, (BinaryIO, BytesIO, BufferedRandom)):
            self._resourceHandle = self.uri
        else:
            self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb"))
        try:
            self._initialize()
        except BaseException:
            # ``__exit__`` is never invoked when ``__enter__`` raises, so the
            # resources acquired above have to be released here to avoid
            # leaking an open handle for invalid files.
            self._exitStack.close()
            self._resourceHandle = None
            raise
        return self

    def __exit__(
        self,
        type: type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Release anything opened on entry and forget the handle.

        Exceptions raised inside the context block are not suppressed.
        """
        assert self._resourceHandle is not None
        self._exitStack.close()
        self._resourceHandle = None

    def readHeader(self) -> str | None:
        """Read the header of the save file without loading the graph.

        The helper enters its own context for the duration of the read, so
        this method is meant to be called outside of a ``with`` block.

        Returns
        -------
        header : `str` or `None`
            Whatever the version-specific deserializer's ``unpackHeader``
            returns for the header bytes.
        """
        with self as handle:
            result = handle.deserializer.unpackHeader(handle._readBytes(*handle.headerBytesRange))
        return result