# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("DESERIALIZER_MAP",)

import json
import lzma
import pickle
import struct
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, ClassVar, DefaultDict, Dict, Optional, Set, Tuple, Type

import networkx as nx
from lsst.daf.butler import (
    DimensionConfig,
    DimensionRecord,
    DimensionUniverse,
    Quantum,
    SerializedDimensionRecord,
)
from lsst.utils import doImportType

from ..config import PipelineTaskConfig
from ..pipeline import TaskDef
from ..pipelineTask import PipelineTask
from ._implDetails import DatasetTypeName, _DatasetTracker
from .quantumNode import QuantumNode, SerializedQuantumNode

if TYPE_CHECKING:
    from .graph import QuantumGraph


class StructSizeDescriptor:
    """A class-level property that reports the size (number of bytes) of
    whatever the struct format string is for a deserializer.
    """

    def __get__(self, inst: Optional[DeserializerBase], owner: Type[DeserializerBase]) -> int:
        return struct.calcsize(owner.FMT_STRING())
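

# For example, the ">QQ" format string used by DeserializerV1 below describes
# two big endian unsigned 64 bit integers, so struct.calcsize(">QQ"), and
# therefore that deserializer's structSize, is 16.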


@dataclass
class DeserializerBase(ABC):
    @classmethod
    @abstractmethod
    def FMT_STRING(cls) -> str:  # noqa: N805  # flake8 wants self
        raise NotImplementedError("Base class does not implement this method")

    structSize: ClassVar[StructSizeDescriptor]

    preambleSize: int
    sizeBytes: bytes

    def __init_subclass__(cls) -> None:
        # attach the size descriptor
        cls.structSize = StructSizeDescriptor()
        super().__init_subclass__()

    def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
        """Transform the raw bytes corresponding to the header of a save
        into a string of the header information. Returns `None` if the save
        format has no header string implementation (such as save format 1,
        which is all pickle).

        Parameters
        ----------
        rawHeader : bytes
            The bytes that are to be parsed into the header information.
            These are the bytes after the preamble and ``structSize`` bytes
            and before the ``headerSize`` bytes.
        """
        raise NotImplementedError("Base class does not implement this method")

    @property
    def headerSize(self) -> int:
        """The number of bytes from the beginning of the file to the end of
        the metadata.
        """
        raise NotImplementedError("Base class does not implement this method")

    def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
        """Parse the supplied raw bytes into the header information and the
        byte ranges of specific TaskDefs and QuantumNodes.

        Parameters
        ----------
        rawHeader : bytes
            The bytes that are to be parsed into the header information.
            These are the bytes after the preamble and ``structSize`` bytes
            and before the ``headerSize`` bytes.
        """
        raise NotImplementedError("Base class does not implement this method")

    def constructGraph(
        self,
        nodes: set[uuid.UUID],
        _readBytes: Callable[[int, int], bytes],
        universe: Optional[DimensionUniverse] = None,
    ) -> QuantumGraph:
        """Construct a graph from the deserialized information.

        Parameters
        ----------
        nodes : `set` of `uuid.UUID`
            The nodes to include in the graph.
        _readBytes : callable
            A callable that can be used to read bytes from the file handle.
            The callable takes two ints, start and stop, to use as the
            numerical bounds to read, and returns a byte stream.
        universe : `~lsst.daf.butler.DimensionUniverse`
            The singleton of all dimensions known to the middleware registry.
        """
        raise NotImplementedError("Base class does not implement this method")

    def description(self) -> str:
        """Return the description of the serialized data format."""
        raise NotImplementedError("Base class does not implement this method")


Version1Description = """
The save file starts with the first few bytes corresponding to the magic bytes
in the QuantumGraph: `qgraph4\xf6\xe8\xa9`.

The next few bytes are 2 big endian unsigned 64 bit integers.

The first unsigned 64 bit integer corresponds to the number of bytes of a
python mapping of TaskDef labels to the byte ranges in the save file where the
definition can be loaded.

The second unsigned 64 bit integer corresponds to the number of bytes of a
python mapping of QuantumGraph Node number to the byte ranges in the save file
where the node can be loaded. The byte range is indexed starting after
the `header` bytes of the magic bytes, size bytes, and bytes of the two
mappings.

Each of the above mappings is pickled and then lzma compressed, so to
deserialize the bytes, lzma decompression must be performed first and the
result passed to the python pickle loader.

As stated above, each map contains byte ranges of the corresponding
datastructure. These bytes are also lzma compressed pickles, and should
be deserialized in a similar manner. The byte range is indexed starting after
the `header` bytes of the magic bytes, size bytes, and bytes of the two
mappings.

In addition to the TaskDef byte locations, the TaskDef map also contains
an additional key '__GraphBuildID'. The value associated with this is the
unique id assigned to the graph at its creation time.
"""


@dataclass
class DeserializerV1(DeserializerBase):
    @classmethod
    def FMT_STRING(cls) -> str:
        return ">QQ"

    def __post_init__(self) -> None:
        self.taskDefMapSize, self.nodeMapSize = struct.unpack(self.FMT_STRING(), self.sizeBytes)

    @property
    def headerSize(self) -> int:
        return self.preambleSize + self.structSize + self.taskDefMapSize + self.nodeMapSize

    def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
        returnValue = SimpleNamespace()
        returnValue.taskDefMap = pickle.loads(rawHeader[: self.taskDefMapSize])
        returnValue._buildId = returnValue.taskDefMap["__GraphBuildID"]
        returnValue.map = pickle.loads(rawHeader[self.taskDefMapSize :])
        returnValue.metadata = None
        self.returnValue = returnValue
        return returnValue

    def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
        return None

    def constructGraph(
        self,
        nodes: set[uuid.UUID],
        _readBytes: Callable[[int, int], bytes],
        universe: Optional[DimensionUniverse] = None,
    ) -> QuantumGraph:
        # need to import here to avoid cyclic imports
        from . import QuantumGraph

        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, uuid.UUID] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # Get the bytes to read from the map
            start, stop = self.returnValue.map[node]
            start += self.headerSize
            stop += self.headerSize

            # read the specified bytes. Bytes are compressed, so decompress
            # them
            dump = lzma.decompress(_readBytes(start, stop))

            # reconstruct node
            qNode = pickle.loads(dump)
            object.__setattr__(qNode, "nodeId", uuid.uuid4())

            # look up the saved node's task. If it has been loaded, attach
            # it; if not, read in the taskDef first, and then load it
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self.returnValue.taskDefMap[nodeTask]
                start += self.headerSize
                stop += self.headerSize

                # load the taskDef. Bytes are compressed, so decompress them
                taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop)))
                loadedTaskDef[nodeTask] = taskDef
            # Explicitly overload the "frozen-ness" of nodes to attach the
            # taskDef back into the un-persisted node
            object.__setattr__(qNode, "taskDef", loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(
            quanta,
            _quantumToNodeId=quantumToNodeId,
            _buildId=self.returnValue._buildId,
            metadata=self.returnValue.metadata,
            universe=universe,
        )
        return qGraph

    def description(self) -> str:
        return Version1Description


Version2Description = """
The save file starts with the first few bytes corresponding to the magic bytes
in the QuantumGraph: `qgraph4\xf6\xe8\xa9`.

The next few bytes are a big endian unsigned long long.

The unsigned long long corresponds to the number of bytes of a python mapping
of header information. This mapping is encoded into json and then lzma
compressed, meaning the operations must be performed in the opposite order to
deserialize.

The json encoded header mapping contains 4 fields: TaskDefs, GraphBuildID,
Nodes, and Metadata.

The `TaskDefs` key corresponds to a value which is a mapping of Task label to
task data. The task data is a mapping of key to value, where the only key is
`bytes` and it corresponds to a tuple of a byte range of the start, stop
bytes (indexed after all the header bytes).

The `GraphBuildID` key corresponds with a string that is the unique id
assigned to this graph when it was created.

The `Nodes` key is like the `TaskDefs` key except it corresponds to
QuantumNodes instead of TaskDefs. Another important difference is that JSON
formatting does not allow using numbers as keys, and this mapping is keyed by
the node number. Thus it is stored in JSON as a list of key, value pairs,
which is rebuilt into a mapping on load.

The `Metadata` key is a mapping of strings to associated values. This metadata
may be anything that is important to be transported alongside the graph.

As stated above, each map contains byte ranges of the corresponding
datastructure. These bytes are also lzma compressed pickles, and should
be deserialized in a similar manner.
"""


@dataclass
class DeserializerV2(DeserializerBase):
    @classmethod
    def FMT_STRING(cls) -> str:
        return ">Q"

    def __post_init__(self) -> None:
        (self.mapSize,) = struct.unpack(self.FMT_STRING(), self.sizeBytes)

    @property
    def headerSize(self) -> int:
        return self.preambleSize + self.structSize + self.mapSize

    def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
        uncompressedHeaderMap = self.unpackHeader(rawHeader)
        if uncompressedHeaderMap is None:
            raise ValueError(
                "This error is not possible because self.unpackHeader cannot return None,"
                " but the check is needed to satisfy type checkers"
            )
        header = json.loads(uncompressedHeaderMap)
        returnValue = SimpleNamespace()
        returnValue.taskDefMap = header["TaskDefs"]
        returnValue._buildId = header["GraphBuildID"]
        returnValue.map = dict(header["Nodes"])
        returnValue.metadata = header["Metadata"]
        self.returnValue = returnValue
        return returnValue

    def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
        return lzma.decompress(rawHeader).decode()

    def constructGraph(
        self,
        nodes: set[uuid.UUID],
        _readBytes: Callable[[int, int], bytes],
        universe: Optional[DimensionUniverse] = None,
    ) -> QuantumGraph:
        # need to import here to avoid cyclic imports
        from . import QuantumGraph

        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, uuid.UUID] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # Get the bytes to read from the map
            start, stop = self.returnValue.map[node]["bytes"]
            start += self.headerSize
            stop += self.headerSize

            # read the specified bytes. Bytes are compressed, so decompress
            # them
            dump = lzma.decompress(_readBytes(start, stop))

            # reconstruct node
            qNode = pickle.loads(dump)
            object.__setattr__(qNode, "nodeId", uuid.uuid4())

            # look up the saved node's task. If it has been loaded, attach
            # it; if not, read in the taskDef first, and then load it
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self.returnValue.taskDefMap[nodeTask]["bytes"]
                start += self.headerSize
                stop += self.headerSize

                # load the taskDef. Bytes are compressed, so decompress them
                taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop)))
                loadedTaskDef[nodeTask] = taskDef
            # Explicitly overload the "frozen-ness" of nodes to attach the
            # taskDef back into the un-persisted node
            object.__setattr__(qNode, "taskDef", loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(
            quanta,
            _quantumToNodeId=quantumToNodeId,
            _buildId=self.returnValue._buildId,
            metadata=self.returnValue.metadata,
            universe=universe,
        )
        return qGraph

    def description(self) -> str:
        return Version2Description


Version3Description = """
The save file starts with the first few bytes corresponding to the magic bytes
in the QuantumGraph: `qgraph4\xf6\xe8\xa9`.

The next few bytes are a big endian unsigned long long.

The unsigned long long corresponds to the number of bytes of a mapping
of header information. This mapping is encoded into json and then lzma
compressed, meaning the operations must be performed in the opposite order to
deserialize.

The json encoded header mapping contains 5 fields: GraphBuildID, TaskDefs,
Nodes, Metadata, and DimensionRecords.

The `GraphBuildID` key corresponds with a string that is the unique id
assigned to this graph when it was created.

The `TaskDefs` key corresponds to a value which is a mapping of Task label to
task data. The task data is a mapping of key to value. The keys of this
mapping are `bytes`, `inputs`, and `outputs`.

The `TaskDefs` `bytes` key corresponds to a tuple of a byte range of the
start, stop bytes (indexed after all the header bytes). This byte range
corresponds to a lzma compressed json mapping. This mapping has keys of
`taskName`, corresponding to a fully qualified python class, `config`, a
pex_config string that is used to configure the class, and `label`, which
corresponds to a string that uniquely identifies the task within a given
execution pipeline.

The `TaskDefs` `inputs` key is associated with a list of tuples where each
tuple is a label of a task that is considered coming before a given task, and
the name of the dataset that is shared between the tasks (think node and edge
in a graph sense).

The `TaskDefs` `outputs` key is like `inputs` except the values in the list
correspond to all the output connections of a task.

The `Nodes` key is also a json mapping with keys corresponding to the UUIDs of
QuantumNodes. The value associated with each of these keys is another mapping
with the keys `bytes`, `inputs`, and `outputs`.

The `Nodes` `bytes` key corresponds to a tuple of a byte range of the start,
stop bytes (indexed after all the header bytes). These bytes are a lzma
compressed json mapping which contains many sub elements; this mapping will be
referred to as the SerializedQuantumNode (related to the python class it
corresponds to).

SerializedQuantumNodes have 3 keys: `quantum`, corresponding to a json mapping
(described below) referred to as a SerializedQuantum; `taskLabel`, a string
which corresponds to a label in the `TaskDefs` mapping; and `nodeId`.

A SerializedQuantum has many keys: taskName, dataId, datasetTypeMapping,
initInputs, inputs, outputs, and dimensionRecords.

The `Nodes` mapping is like the `TaskDefs` mapping except it corresponds to
QuantumNodes instead of TaskDefs, and the keys of the mapping are string
representations of the UUIDs of the QuantumNodes.

The `Metadata` key is a mapping of strings to associated values. This metadata
may be anything that is important to be transported alongside the graph.

As stated above, each map contains byte ranges of the corresponding
datastructure. These bytes are lzma compressed json, and should be
deserialized in a similar manner.
"""


@dataclass
class DeserializerV3(DeserializerBase):
    @classmethod
    def FMT_STRING(cls) -> str:
        return ">Q"

    def __post_init__(self) -> None:
        self.infoSize: int
        (self.infoSize,) = struct.unpack(self.FMT_STRING(), self.sizeBytes)

    @property
    def headerSize(self) -> int:
        return self.preambleSize + self.structSize + self.infoSize

    def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
        uncompressedinfoMap = self.unpackHeader(rawHeader)
        assert uncompressedinfoMap is not None  # for python typing; this variant can't be None
        infoMap = json.loads(uncompressedinfoMap)
        infoMappings = SimpleNamespace()
        infoMappings.taskDefMap = infoMap["TaskDefs"]
        infoMappings._buildId = infoMap["GraphBuildID"]
        infoMappings.map = {uuid.UUID(k): v for k, v in infoMap["Nodes"]}
        infoMappings.metadata = infoMap["Metadata"]
        infoMappings.dimensionRecords = {}
        for k, v in infoMap["DimensionRecords"].items():
            infoMappings.dimensionRecords[int(k)] = SerializedDimensionRecord(**v)
        # It is important that this is a get call, so that saved quantum
        # graphs that do not contain a saved universe are supported without
        # changing the save format
        if (universeConfig := infoMap.get("universe")) is not None:
            universe = DimensionUniverse(config=DimensionConfig(universeConfig))
        else:
            universe = DimensionUniverse()
        infoMappings.universe = universe
        self.infoMappings = infoMappings
        return infoMappings

    def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
        return lzma.decompress(rawHeader).decode()

    def constructGraph(
        self,
        nodes: set[uuid.UUID],
        _readBytes: Callable[[int, int], bytes],
        universe: Optional[DimensionUniverse] = None,
    ) -> QuantumGraph:
        # need to import here to avoid cyclic imports
        from . import QuantumGraph

        graph = nx.DiGraph()
        loadedTaskDef: Dict[str, TaskDef] = {}
        container = {}
        datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
        taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
        reconstitutedDimensions: Dict[int, Tuple[str, DimensionRecord]] = {}

        if universe is not None:
            if not universe.isCompatibleWith(self.infoMappings.universe):
                saved = self.infoMappings.universe
                raise RuntimeError(
                    f"The saved dimension universe ({saved.namespace}@v{saved.version}) is not "
                    f"compatible with the supplied universe ({universe.namespace}@v{universe.version})."
                )
        else:
            universe = self.infoMappings.universe

        for node in nodes:
            start, stop = self.infoMappings.map[node]["bytes"]
            start, stop = start + self.headerSize, stop + self.headerSize
            # Read in the bytes corresponding to the node to load and
            # decompress them
            dump = json.loads(lzma.decompress(_readBytes(start, stop)))

            # Turn the json back into the pydantic model
            nodeDeserialized = SerializedQuantumNode.direct(**dump)
            # attach the dictionary of dimension records to the pydantic
            # model; these are stored separately because they are used over
            # and over, and this saves a lot of space and time
            nodeDeserialized.quantum.dimensionRecords = self.infoMappings.dimensionRecords
            # get the label for the current task
            nodeTaskLabel = nodeDeserialized.taskLabel

            if nodeTaskLabel not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self.infoMappings.taskDefMap[nodeTaskLabel]["bytes"]
                start, stop = start + self.headerSize, stop + self.headerSize

                # bytes are compressed, so decompress them
                taskDefDump = json.loads(lzma.decompress(_readBytes(start, stop)))
                taskClass: Type[PipelineTask] = doImportType(taskDefDump["taskName"])
                config: PipelineTaskConfig = taskClass.ConfigClass()
                config.loadFromStream(taskDefDump["config"])
                # Rebuild TaskDef
                recreatedTaskDef = TaskDef(
                    taskName=taskDefDump["taskName"],
                    taskClass=taskClass,
                    config=config,
                    label=taskDefDump["label"],
                )
                loadedTaskDef[nodeTaskLabel] = recreatedTaskDef

                # rebuild the mappings that associate dataset type names with
                # TaskDefs
                for _, input in self.infoMappings.taskDefMap[nodeTaskLabel]["inputs"]:
                    datasetDict.addConsumer(DatasetTypeName(input), recreatedTaskDef)

                added = set()
                for outputConnection in self.infoMappings.taskDefMap[nodeTaskLabel]["outputs"]:
                    typeName = outputConnection[1]
                    if typeName not in added:
                        added.add(typeName)
                        datasetDict.addProducer(DatasetTypeName(typeName), recreatedTaskDef)

            # reconstitute the node, passing in the dictionaries for the
            # loaded TaskDefs and dimension records. These are used to ensure
            # that each unique record is only loaded once
            qnode = QuantumNode.from_simple(nodeDeserialized, loadedTaskDef, universe, reconstitutedDimensions)
            container[qnode.nodeId] = qnode
            taskToQuantumNode[loadedTaskDef[nodeTaskLabel]].add(qnode)

            # recreate the relations between each node from stored info
            graph.add_node(qnode)
            for id in self.infoMappings.map[qnode.nodeId]["inputs"]:
                # uuid is stored as a string; turn it back into a uuid
                id = uuid.UUID(id)
                # if the id is not yet in the container, don't make a
                # connection. This is not an issue, because once it is, that
                # id will add the reverse connection
                if id in container:
                    graph.add_edge(container[id], qnode)
            for id in self.infoMappings.map[qnode.nodeId]["outputs"]:
                # uuid is stored as a string; turn it back into a uuid
                id = uuid.UUID(id)
                # if the id is not yet in the container, don't make a
                # connection. This is not an issue, because once it is, that
                # id will add the reverse connection
                if id in container:
                    graph.add_edge(qnode, container[id])

        newGraph = object.__new__(QuantumGraph)
        newGraph._metadata = self.infoMappings.metadata
        newGraph._buildId = self.infoMappings._buildId
        newGraph._datasetDict = datasetDict
        newGraph._nodeIdMap = container
        newGraph._count = len(nodes)
        newGraph._taskToQuantumNode = dict(taskToQuantumNode.items())
        newGraph._taskGraph = datasetDict.makeNetworkXGraph()
        newGraph._connectedQuanta = graph
        return newGraph


DESERIALIZER_MAP = {1: DeserializerV1, 2: DeserializerV2, 3: DeserializerV3}
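

# A minimal sketch, for illustration only, of how loading code could use
# DESERIALIZER_MAP: look up the class for the save-format version recorded in
# the file and construct it from the preamble size and the raw size bytes
# (16 bytes for version 1, 8 for versions 2 and 3). The argument handling here
# is an assumption; the real reader lives in the QuantumGraph load machinery.
def _exampleSelectDeserializer(
    saveVersion: int, preambleSize: int, sizeBytes: bytes
) -> DeserializerBase:
    return DESERIALIZER_MAP[saveVersion](preambleSize, sizeBytes)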