Coverage for python/lsst/pipe/base/graph/_loadHelpers.py: 27%
83 statements
coverage.py v7.4.4, created at 2024-04-04 02:56 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("LoadHelper",)
31import struct
32from collections.abc import Iterable
33from contextlib import AbstractContextManager, ExitStack
34from dataclasses import dataclass
35from io import BufferedRandom, BytesIO
36from types import TracebackType
37from typing import TYPE_CHECKING, BinaryIO
38from uuid import UUID
40from lsst.daf.butler import DimensionUniverse
41from lsst.daf.butler.persistence_context import PersistenceContextVars
42from lsst.resources import ResourceHandleProtocol, ResourcePath
44if TYPE_CHECKING:
45 from ._versionDeserializers import DeserializerBase
46 from .graph import QuantumGraph
49@dataclass
50class LoadHelper(AbstractContextManager["LoadHelper"]):
51 """Helper class to assist with selecting the appropriate loader
52 and managing any contexts that may be needed.
54 This helper will raise a `ValueError` if the specified file does not appear
55 to be a valid `QuantumGraph` save file.
56 """
58 uri: ResourcePath | BinaryIO
59 """ResourcePath object from which the `QuantumGraph` is to be loaded
60 """
61 minimumVersion: int
62 """
63 Minimum version of a save file to load. Set to -1 to load all
64 versions. Older versions may need to be loaded, and re-saved
65 to upgrade them to the latest format before they can be used in
66 production.
67 """
69 def __post_init__(self) -> None:
70 self._resourceHandle: ResourceHandleProtocol | None = None
71 self._exitStack = ExitStack()
73 def _initialize(self) -> None:
74 # need to import here to avoid cyclic imports
75 from ._versionDeserializers import DESERIALIZER_MAP
76 from .graph import MAGIC_BYTES, STRUCT_FMT_BASE
78 # Read the first few bytes which correspond to the magic identifier
79 # bytes, and save version
80 magicSize = len(MAGIC_BYTES)
81 # read in just the fmt base to determine the save version
82 fmtSize = struct.calcsize(STRUCT_FMT_BASE)
83 preambleSize = magicSize + fmtSize
84 headerBytes = self._readBytes(0, preambleSize)
85 magic = headerBytes[:magicSize]
86 versionBytes = headerBytes[magicSize:]
88 save_version = self._validateSave(magic, versionBytes)
90 # select the appropriate deserializer for this save version
91 deserializerClass = DESERIALIZER_MAP[save_version]
93 # read in the bytes corresponding to the mappings and initialize the
94 # deserializer. This will be the bytes that describe the following
95 # byte boundaries of the header info
96 sizeBytes = self._readBytes(preambleSize, preambleSize + deserializerClass.structSize)
97 # DeserializerBase subclasses are required to have the same constructor
98 # signature as the base class itself, but there is no way to express
99 # this in the type system, so we just tell MyPy to ignore it.
100 self.deserializer: DeserializerBase = deserializerClass(preambleSize, sizeBytes)
101 # get the header byte range for later reading
102 self.headerBytesRange = (preambleSize + deserializerClass.structSize, self.deserializer.headerSize)
104 def _validateSave(self, magic: bytes, versionBytes: bytes) -> int:
105 """Implement validation on input file, prior to attempting to load it
107 Parameters
108 ----------
109 magic : `bytes`
110 The first few bytes of the file, used to verify it is a
111 `QuantumGraph` save file.
112 versionBytes : `bytes`
113 The next few bytes from the beginning of the file, used to parse
114 which version of the `QuantumGraph` file the save corresponds to.
116 Returns
117 -------
118 save_version : `int`
119 The save version parsed from the supplied bytes.
121 Raises
122 ------
123 ValueError
124 Raised if the specified file contains the wrong file signature and
125 is not a `QuantumGraph` save, or if the graph save version is
126 below the minimum specified version.
127 """
128 from .graph import MAGIC_BYTES, SAVE_VERSION, STRUCT_FMT_BASE
130 if magic != MAGIC_BYTES:
131 raise ValueError(
132 "This file does not appear to be a quantum graph save got magic bytes "
133 f"{magic!r}, expected {MAGIC_BYTES!r}"
134 )
136 # unpack the save version bytes and verify it is a version that this
137 # code can understand
138 (save_version,) = struct.unpack(STRUCT_FMT_BASE, versionBytes)
139 # loads can sometimes trigger upgrades in format to a latest version,
140 # in which case accessory code might not match the upgraded graph.
141 # I.E. switching from old node number to UUID. This clause necessitates
142 # that users specifically interact with older graph versions and verify
143 # everything happens appropriately.
144 if save_version < self.minimumVersion:
145 raise ValueError(
146 f"The loaded QuantumGraph is version {save_version}, and the minimum "
147 f"version specified is {self.minimumVersion}. Please re-run this method "
148 "with a lower minimum version, then re-save the graph to automatically upgrade"
149 "to the newest version. Older versions may not work correctly with newer code"
150 )
152 if save_version > SAVE_VERSION:
153 raise ValueError(
154 f"The version of this save file is {save_version}, but this version of"
155 f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}"
156 )
157 return save_version
159 def load(
160 self,
161 universe: DimensionUniverse | None = None,
162 nodes: Iterable[UUID | str] | None = None,
163 graphID: str | None = None,
164 ) -> QuantumGraph:
165 """Load in the specified nodes from the graph.
167 Load in the `QuantumGraph` containing only the nodes specified in the
168 ``nodes`` parameter from the graph specified at object creation. If
169 ``nodes`` is None (the default) the whole graph is loaded.
171 Parameters
172 ----------
173 universe : `~lsst.daf.butler.DimensionUniverse` or None
174 DimensionUniverse instance, not used by the method itself but
175 needed to ensure that registry data structures are initialized.
176 The universe saved with the graph is used, but if one is passed
177 it will be used to validate the compatibility with the loaded
178 graph universe.
179 nodes : `~collections.abc.Iterable` of `UUID` or `str`; or `None`
180 The nodes to load from the graph, loads all if value is None
181 (the default).
182 graphID : `str` or `None`
183 If specified this ID is verified against the loaded graph prior to
184 loading any Nodes. This defaults to None in which case no
185 validation is done.
187 Returns
188 -------
189 graph : `QuantumGraph`
190 The loaded `QuantumGraph` object.
192 Raises
193 ------
194 ValueError
195 Raised if one or more of the nodes requested is not in the
196 `QuantumGraph` or if graphID parameter does not match the graph
197 being loaded.
198 RuntimeError
199 Raised if supplied `~lsst.daf.butler.DimensionUniverse` is not
200 compatible with the `~lsst.daf.butler.DimensionUniverse` saved in
201 the graph. Raised if the method was not called from within a
202 context block.
203 """
204 if self._resourceHandle is None:
205 raise RuntimeError("Load can only be used within a context manager")
207 headerInfo = self.deserializer.readHeaderInfo(self._readBytes(*self.headerBytesRange))
208 # verify this is the expected graph
209 if graphID is not None and headerInfo._buildId != graphID:
210 raise ValueError("graphID does not match that of the graph being loaded")
211 # Read in specified nodes, or all the nodes
212 nodeSet: set[UUID]
213 if nodes is None:
214 nodeSet = set(headerInfo.map.keys())
215 else:
216 # only some bytes are being read using the reader specialized for
217 # this class
218 # create a set to ensure nodes are only loaded once
219 nodeSet = {UUID(n) if isinstance(n, str) else n for n in nodes}
220 # verify that all nodes requested are in the graph
221 remainder = nodeSet - headerInfo.map.keys()
222 if remainder:
223 raise ValueError(
224 f"Nodes {remainder} were requested, but could not be found in the input graph"
225 )
226 _readBytes = self._readBytes
227 if universe is None:
228 universe = headerInfo.universe
229 # use the daf butler context vars to aid in ensuring deduplication in
230 # object instantiation.
231 runner = PersistenceContextVars()
232 graph = runner.run(self.deserializer.constructGraph, nodeSet, _readBytes, universe)
233 return graph
235 def _readBytes(self, start: int, stop: int) -> bytes:
236 """Load the specified byte range from the ResourcePath object
238 Parameters
239 ----------
240 start : `int`
241 The beginning byte location to read
242 stop : `int`
243 The end byte location to read
245 Returns
246 -------
247 result : `bytes`
248 The byte range specified from the
249 `~lsst.resources.ResourceHandleProtocol`.
251 Raises
252 ------
253 RuntimeError
254 Raise if the method was not called from within a context block.
255 """
256 if self._resourceHandle is None:
257 raise RuntimeError("_readBytes must be called from within a context block")
258 self._resourceHandle.seek(start)
259 return self._resourceHandle.read(stop - start)
261 def __enter__(self) -> "LoadHelper":
262 if isinstance(self.uri, BinaryIO | BytesIO | BufferedRandom):
263 self._resourceHandle = self.uri
264 else:
265 self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb"))
266 self._initialize()
267 return self
269 def __exit__(
270 self,
271 type: type[BaseException] | None,
272 value: BaseException | None,
273 traceback: TracebackType | None,
274 ) -> None:
275 assert self._resourceHandle is not None
276 self._exitStack.close()
277 self._resourceHandle = None
279 def readHeader(self) -> str | None:
280 with self as handle:
281 result = handle.deserializer.unpackHeader(self._readBytes(*self.headerBytesRange))
282 return result