Coverage for python/lsst/pipe/base/graph/_versionDeserializers.py: 30% of 240 statements (coverage.py v6.4.2, created at 2022-07-14 16:10 -0700)

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("DESERIALIZER_MAP",)

import json
import lzma
import pickle
import struct
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, ClassVar, DefaultDict, Dict, Optional, Set, Tuple, Type

import networkx as nx
from lsst.daf.butler import (
    DimensionConfig,
    DimensionRecord,
    DimensionUniverse,
    Quantum,
    SerializedDimensionRecord,
)
from lsst.utils import doImportType

from ..config import PipelineTaskConfig
from ..pipeline import TaskDef
from ..pipelineTask import PipelineTask
from ._implDetails import DatasetTypeName, _DatasetTracker
from .quantumNode import QuantumNode, SerializedQuantumNode

if TYPE_CHECKING:
    from .graph import QuantumGraph


class StructSizeDescriptor:
    """This is basically a class-level property. It exists to report the size
    (number of bytes) of whatever the format string is for a deserializer.
    """

    def __get__(self, inst: Optional[DeserializerBase], owner: Type[DeserializerBase]) -> int:
        return struct.calcsize(owner.FMT_STRING())
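
# For example (illustrative only): DeserializerV1 below declares
# FMT_STRING() == ">QQ", so DeserializerV1.structSize evaluates to
# struct.calcsize(">QQ") == 16 bytes (two big-endian unsigned 64-bit integers).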


# MyPy doesn't seem to like the idea of an abstract dataclass. It seems to
# work, but maybe we're doing something that isn't really supported (or maybe
# I misunderstood the error message).
@dataclass  # type: ignore
class DeserializerBase(ABC):
    @classmethod
    @abstractmethod
    def FMT_STRING(cls) -> str:  # noqa: N805 # flake8 wants self
        raise NotImplementedError("Base class does not implement this method")

    structSize: ClassVar[StructSizeDescriptor]

    preambleSize: int
    sizeBytes: bytes

    def __init_subclass__(cls) -> None:
        # attach the size descriptor
        cls.structSize = StructSizeDescriptor()
        super().__init_subclass__()

    def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
        """Transform the raw bytes corresponding to the header of a save into
        a string of the header information. Returns `None` if the save format
        has no header string implementation (such as save format 1, which is
        all pickle).

        Parameters
        ----------
        rawHeader : bytes
            The bytes that are to be parsed into the header information. These
            are the bytes after the preamble and structSize number of bytes
            and before the headerSize bytes.
        """
        raise NotImplementedError("Base class does not implement this method")

    @property
    def headerSize(self) -> int:
        """Returns the number of bytes from the beginning of the file to the
        end of the metadata.
        """
        raise NotImplementedError("Base class does not implement this method")

    def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
        """Parse the supplied raw bytes into the header information and the
        byte ranges of specific TaskDefs and QuantumNodes.

        Parameters
        ----------
        rawHeader : bytes
            The bytes that are to be parsed into the header information. These
            are the bytes after the preamble and structSize number of bytes
            and before the headerSize bytes.
        """
        raise NotImplementedError("Base class does not implement this method")

    def constructGraph(
        self,
        nodes: set[uuid.UUID],
        _readBytes: Callable[[int, int], bytes],
        universe: Optional[DimensionUniverse] = None,
    ) -> QuantumGraph:
        """Construct a graph from the deserialized information.

        Parameters
        ----------
        nodes : `set` of `uuid.UUID`
            The nodes to include in the graph.
        _readBytes : callable
            A callable that can be used to read bytes from the file handle.
            It takes two ints, start and stop, to use as the numerical bounds
            to read, and returns a byte stream.
        universe : `~lsst.daf.butler.DimensionUniverse`
            The singleton of all dimensions known to the middleware registry.
        """
        raise NotImplementedError("Base class does not implement this method")

    def description(self) -> str:
        """Return the description of the serialized data format."""
        raise NotImplementedError("Base class does not implement this method")


Version1Description = """
The save file starts with the first few bytes corresponding to the magic bytes
in the QuantumGraph: `qgraph4\xf6\xe8\xa9`.

The next few bytes are 2 big endian unsigned 64 bit integers.

The first unsigned 64 bit integer corresponds to the number of bytes of a
python mapping of TaskDef labels to the byte ranges in the save file where the
definition can be loaded.

The second unsigned 64 bit integer corresponds to the number of bytes of a
python mapping of QuantumGraph Node number to the byte ranges in the save file
where the node can be loaded. The byte range is indexed starting after
the `header` bytes of the magic bytes, size bytes, and bytes of the two
mappings.

Each of the above mappings is pickled and then lzma compressed, so to
deserialize the bytes, lzma decompression must be performed first and the
result passed to the python pickle loader.

As stated above, each map contains byte ranges of the corresponding
data structure. These bytes are also lzma compressed pickles, and should
be deserialized in a similar manner. The byte range is indexed starting after
the `header` bytes of the magic bytes, size bytes, and bytes of the two
mappings.

In addition to the TaskDef byte locations, the TaskDef map also contains
an additional key '__GraphBuildID'. The value associated with this is the
unique id assigned to the graph at its creation time.
"""


@dataclass
class DeserializerV1(DeserializerBase):
    @classmethod
    def FMT_STRING(cls) -> str:
        return ">QQ"

    def __post_init__(self) -> None:
        self.taskDefMapSize, self.nodeMapSize = struct.unpack(self.FMT_STRING(), self.sizeBytes)

    @property
    def headerSize(self) -> int:
        return self.preambleSize + self.structSize + self.taskDefMapSize + self.nodeMapSize

    def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
        returnValue = SimpleNamespace()
        returnValue.taskDefMap = pickle.loads(rawHeader[: self.taskDefMapSize])
        returnValue._buildId = returnValue.taskDefMap["__GraphBuildID"]
        returnValue.map = pickle.loads(rawHeader[self.taskDefMapSize :])
        returnValue.metadata = None
        self.returnValue = returnValue
        return returnValue

    def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
        return None

    def constructGraph(
        self,
        nodes: set[uuid.UUID],
        _readBytes: Callable[[int, int], bytes],
        universe: Optional[DimensionUniverse] = None,
    ) -> QuantumGraph:
        # need to import here to avoid cyclic imports
        from . import QuantumGraph

        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, uuid.UUID] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # Get the bytes to read from the map
            start, stop = self.returnValue.map[node]
            start += self.headerSize
            stop += self.headerSize

            # read the specified bytes; they are compressed, so decompress
            # them
            dump = lzma.decompress(_readBytes(start, stop))

            # reconstruct node
            qNode = pickle.loads(dump)
            object.__setattr__(qNode, "nodeId", uuid.uuid4())

            # Look up the node's saved task. If its TaskDef has already been
            # loaded, attach it; if not, read in the taskDef first, and then
            # attach it
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self.returnValue.taskDefMap[nodeTask]
                start += self.headerSize
                stop += self.headerSize

                # load the taskDef; the bytes are compressed, so decompress
                # them
                taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop)))
                loadedTaskDef[nodeTask] = taskDef
            # Explicitly overload the "frozen-ness" of nodes to attach the
            # taskDef back into the un-persisted node
            object.__setattr__(qNode, "taskDef", loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(
            quanta,
            _quantumToNodeId=quantumToNodeId,
            _buildId=self.returnValue._buildId,
            metadata=self.returnValue.metadata,
        )
        return qGraph

    def description(self) -> str:
        return Version1Description


Version2Description = """
The save file starts with the first few bytes corresponding to the magic bytes
in the QuantumGraph: `qgraph4\xf6\xe8\xa9`.

The next few bytes are a big endian unsigned long long.

The unsigned long long corresponds to the number of bytes of a python mapping
of header information. This mapping is encoded into json and then lzma
compressed, meaning the operations must be performed in the opposite order to
deserialize.

The json encoded header mapping contains 4 fields: TaskDefs, GraphBuildID,
Nodes, and Metadata.

The `TaskDefs` key corresponds to a value which is a mapping of Task label to
task data. The task data is a mapping of key to value, where the only key is
`bytes` and it corresponds to a tuple of a byte range of the start, stop
bytes (indexed after all the header bytes).

The `GraphBuildID` corresponds with a string that is the unique id assigned to
this graph when it was created.

The `Nodes` key is like the `TaskDefs` key except it corresponds to
QuantumNodes instead of TaskDefs. Another important difference is that JSON
formatting does not allow using numbers as keys, and this mapping is keyed by
the node number. Thus it is stored in JSON as a list of key, value pairs that
is turned back into a mapping on load.

The `Metadata` key is a mapping of strings to associated values. This metadata
may be anything that is important to be transported alongside the graph.

As stated above, each map contains byte ranges of the corresponding
data structure. These bytes are also lzma compressed pickles, and should
be deserialized in a similar manner.
"""


@dataclass
class DeserializerV2(DeserializerBase):
    @classmethod
    def FMT_STRING(cls) -> str:
        return ">Q"

    def __post_init__(self) -> None:
        (self.mapSize,) = struct.unpack(self.FMT_STRING(), self.sizeBytes)

    @property
    def headerSize(self) -> int:
        return self.preambleSize + self.structSize + self.mapSize

    def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
        uncompressedHeaderMap = self.unpackHeader(rawHeader)
        if uncompressedHeaderMap is None:
            raise ValueError(
                "This error is impossible, because self.unpackHeader never returns None;"
                " the check exists to satisfy type checkers"
            )
        header = json.loads(uncompressedHeaderMap)
        returnValue = SimpleNamespace()
        returnValue.taskDefMap = header["TaskDefs"]
        returnValue._buildId = header["GraphBuildID"]
        returnValue.map = dict(header["Nodes"])
        returnValue.metadata = header["Metadata"]
        self.returnValue = returnValue
        return returnValue

    def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
        return lzma.decompress(rawHeader).decode()

    def constructGraph(
        self,
        nodes: set[uuid.UUID],
        _readBytes: Callable[[int, int], bytes],
        universe: Optional[DimensionUniverse] = None,
    ) -> QuantumGraph:
        # need to import here to avoid cyclic imports
        from . import QuantumGraph

        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, uuid.UUID] = {}
        loadedTaskDef = {}
        # loop over the nodes specified above
        for node in nodes:
            # Get the bytes to read from the map
            start, stop = self.returnValue.map[node]["bytes"]
            start += self.headerSize
            stop += self.headerSize

            # read the specified bytes; they are compressed, so decompress
            # them
            dump = lzma.decompress(_readBytes(start, stop))

            # reconstruct node
            qNode = pickle.loads(dump)
            object.__setattr__(qNode, "nodeId", uuid.uuid4())

            # Look up the node's saved task. If its TaskDef has already been
            # loaded, attach it; if not, read in the taskDef first, and then
            # attach it
            nodeTask = qNode.taskDef
            if nodeTask not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self.returnValue.taskDefMap[nodeTask]["bytes"]
                start += self.headerSize
                stop += self.headerSize

                # load the taskDef; the bytes are compressed, so decompress
                # them
                taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop)))
                loadedTaskDef[nodeTask] = taskDef
            # Explicitly overload the "frozen-ness" of nodes to attach the
            # taskDef back into the un-persisted node
            object.__setattr__(qNode, "taskDef", loadedTaskDef[nodeTask])
            quanta[qNode.taskDef].add(qNode.quantum)

            # record the node for later processing
            quantumToNodeId[qNode.quantum] = qNode.nodeId

        # construct an empty new QuantumGraph object, and run the associated
        # creation method with the un-persisted data
        qGraph = object.__new__(QuantumGraph)
        qGraph._buildGraphs(
            quanta,
            _quantumToNodeId=quantumToNodeId,
            _buildId=self.returnValue._buildId,
            metadata=self.returnValue.metadata,
        )
        return qGraph

    def description(self) -> str:
        return Version2Description


Version3Description = """
The save file starts with the first few bytes corresponding to the magic bytes
in the QuantumGraph: `qgraph4\xf6\xe8\xa9`.

The next few bytes are a big endian unsigned long long.

The unsigned long long corresponds to the number of bytes of a mapping
of header information. This mapping is encoded into json and then lzma
compressed, meaning the operations must be performed in the opposite order to
deserialize.

The json encoded header mapping contains 5 fields: GraphBuildID, TaskDefs,
Nodes, Metadata, and DimensionRecords.

The `GraphBuildID` key corresponds with a string that is the unique id assigned
to this graph when it was created.

The `TaskDefs` key corresponds to a value which is a mapping of Task label to
task data. The task data is a mapping of key to value. The keys of this mapping
are `bytes`, `inputs`, and `outputs`.

The `TaskDefs` `bytes` key corresponds to a tuple of a byte range of the
start, stop bytes (indexed after all the header bytes). This byte range
corresponds to a lzma compressed json mapping. This mapping has keys of
`taskName`, corresponding to a fully qualified python class, `config`, a
pex_config string that is used to configure the class, and `label`, which
corresponds to a string that uniquely identifies the task within a given
execution pipeline.

The `TaskDefs` `inputs` key is associated with a list of tuples where each
tuple is a label of a task that is considered coming before a given task, and
the name of the dataset that is shared between the tasks (think node and edge
in a graph sense).

The `TaskDefs` `outputs` key is like inputs except the values in the list
correspond to all the output connections of a task.

The `Nodes` key is also a json mapping with keys corresponding to the UUIDs of
QuantumNodes. The values associated with these keys is another mapping with
the keys `bytes`, `inputs`, and `outputs`.

`Nodes` key `bytes` corresponds to a tuple of a byte range of the start, stop
bytes (indexed after all the header bytes). These bytes are a lzma compressed
json mapping which contains many sub elements; this mapping will be referred to
as the SerializedQuantumNode (related to the python class it corresponds to).

SerializedQuantumNodes have 3 keys: `quantum`, corresponding to a json mapping
(described below) referred to as a SerializedQuantum; `taskLabel`, a string
which corresponds to a label in the `TaskDefs` mapping; and `nodeId`.

A SerializedQuantum has many keys: taskName, dataId, datasetTypeMapping,
initInputs, inputs, outputs, and dimensionRecords.

Unlike the `TaskDefs` mapping, which is keyed by task label, the `Nodes`
mapping is keyed by the string representations of the UUIDs of the
QuantumNodes.

The `Metadata` key is a mapping of strings to associated values. This metadata
may be anything that is important to be transported alongside the graph.

The `DimensionRecords` key is a mapping of record number (a stringified
integer) to a mapping from which a SerializedDimensionRecord can be
constructed. These records are stored once here and shared by all the quanta
that reference them. The header mapping may also contain an optional
`universe` key holding the DimensionConfig of the dimension universe the graph
was saved with; if it is absent, the default universe is assumed (see
readHeaderInfo below).

As stated above, each map contains byte ranges of the corresponding
data structure. These bytes are lzma compressed json, and should
be deserialized in a similar manner.
"""


@dataclass
class DeserializerV3(DeserializerBase):
    @classmethod
    def FMT_STRING(cls) -> str:
        return ">Q"

    def __post_init__(self) -> None:
        self.infoSize: int
        (self.infoSize,) = struct.unpack(self.FMT_STRING(), self.sizeBytes)

    @property
    def headerSize(self) -> int:
        return self.preambleSize + self.structSize + self.infoSize

    def readHeaderInfo(self, rawHeader: bytes) -> SimpleNamespace:
        uncompressedinfoMap = self.unpackHeader(rawHeader)
        assert uncompressedinfoMap is not None  # for python typing; this variable can't be None
        infoMap = json.loads(uncompressedinfoMap)
        infoMappings = SimpleNamespace()
        infoMappings.taskDefMap = infoMap["TaskDefs"]
        infoMappings._buildId = infoMap["GraphBuildID"]
        infoMappings.map = {uuid.UUID(k): v for k, v in infoMap["Nodes"]}
        infoMappings.metadata = infoMap["Metadata"]
        infoMappings.dimensionRecords = {}
        for k, v in infoMap["DimensionRecords"].items():
            infoMappings.dimensionRecords[int(k)] = SerializedDimensionRecord(**v)
        # It is important that this is a get call, so that it supports
        # versions of saved quantum graphs that might not have a saved
        # universe, without changing the save format
        if (universeConfig := infoMap.get("universe")) is not None:
            universe = DimensionUniverse(config=DimensionConfig(universeConfig))
        else:
            universe = DimensionUniverse()
        infoMappings.universe = universe
        self.infoMappings = infoMappings
        return infoMappings

    def unpackHeader(self, rawHeader: bytes) -> Optional[str]:
        return lzma.decompress(rawHeader).decode()

    def constructGraph(
        self,
        nodes: set[uuid.UUID],
        _readBytes: Callable[[int, int], bytes],
        universe: Optional[DimensionUniverse] = None,
    ) -> QuantumGraph:
        # need to import here to avoid cyclic imports
        from . import QuantumGraph

        graph = nx.DiGraph()
        loadedTaskDef: Dict[str, TaskDef] = {}
        container = {}
        datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
        taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
        reconstitutedDimensions: Dict[int, Tuple[str, DimensionRecord]] = {}

        if universe is not None:
            if not universe.isCompatibleWith(self.infoMappings.universe):
                saved = self.infoMappings.universe
                raise RuntimeError(
                    f"The saved dimension universe ({saved.namespace}@v{saved.version}) is not "
                    f"compatible with the supplied universe ({universe.namespace}@v{universe.version})."
                )
        else:
            universe = self.infoMappings.universe

        for node in nodes:
            start, stop = self.infoMappings.map[node]["bytes"]
            start, stop = start + self.headerSize, stop + self.headerSize
            # Read in the bytes corresponding to the node to load and
            # decompress it
            dump = json.loads(lzma.decompress(_readBytes(start, stop)))

            # Turn the json back into the pydantic model
            nodeDeserialized = SerializedQuantumNode.direct(**dump)
            # attach the dictionary of dimension records to the pydantic
            # model; these are stored separately because they are stored over
            # and over, and this saves a lot of space and time.
            nodeDeserialized.quantum.dimensionRecords = self.infoMappings.dimensionRecords
            # get the label for the current task
            nodeTaskLabel = nodeDeserialized.taskLabel

            if nodeTaskLabel not in loadedTaskDef:
                # Get the byte ranges corresponding to this taskDef
                start, stop = self.infoMappings.taskDefMap[nodeTaskLabel]["bytes"]
                start, stop = start + self.headerSize, stop + self.headerSize

                # bytes are compressed, so decompress them
                taskDefDump = json.loads(lzma.decompress(_readBytes(start, stop)))
                taskClass: Type[PipelineTask] = doImportType(taskDefDump["taskName"])
                config: PipelineTaskConfig = taskClass.ConfigClass()
                config.loadFromStream(taskDefDump["config"])
                # Rebuild TaskDef
                recreatedTaskDef = TaskDef(
                    taskName=taskDefDump["taskName"],
                    taskClass=taskClass,
                    config=config,
                    label=taskDefDump["label"],
                )
                loadedTaskDef[nodeTaskLabel] = recreatedTaskDef

                # rebuild the mappings that associate dataset type names with
                # TaskDefs
                for _, input in self.infoMappings.taskDefMap[nodeTaskLabel]["inputs"]:
                    datasetDict.addConsumer(DatasetTypeName(input), recreatedTaskDef)

                added = set()
                for outputConnection in self.infoMappings.taskDefMap[nodeTaskLabel]["outputs"]:
                    typeName = outputConnection[1]
                    if typeName not in added:
                        added.add(typeName)
                        datasetDict.addProducer(DatasetTypeName(typeName), recreatedTaskDef)

            # reconstitute the node, passing in the dictionaries for the
            # loaded TaskDefs and dimension records. These are used to ensure
            # that each unique record is only loaded once
            qnode = QuantumNode.from_simple(nodeDeserialized, loadedTaskDef, universe, reconstitutedDimensions)
            container[qnode.nodeId] = qnode
            taskToQuantumNode[loadedTaskDef[nodeTaskLabel]].add(qnode)

            # recreate the relations between each node from stored info
            graph.add_node(qnode)
            for id in self.infoMappings.map[qnode.nodeId]["inputs"]:
                # uuid is stored as a string, turn it back into a uuid
                id = uuid.UUID(id)
                # If the id is not yet in the container, don't make a
                # connection. This is not an issue, because once it is, that
                # id will add the reverse connection
                if id in container:
                    graph.add_edge(container[id], qnode)
            for id in self.infoMappings.map[qnode.nodeId]["outputs"]:
                # uuid is stored as a string, turn it back into a uuid
                id = uuid.UUID(id)
                # If the id is not yet in the container, don't make a
                # connection. This is not an issue, because once it is, that
                # id will add the reverse connection
                if id in container:
                    graph.add_edge(qnode, container[id])

        newGraph = object.__new__(QuantumGraph)
        newGraph._metadata = self.infoMappings.metadata
        newGraph._buildId = self.infoMappings._buildId
        newGraph._datasetDict = datasetDict
        newGraph._nodeIdMap = container
        newGraph._count = len(nodes)
        newGraph._taskToQuantumNode = dict(taskToQuantumNode.items())
        newGraph._taskGraph = datasetDict.makeNetworkXGraph()
        newGraph._connectedQuanta = graph
        return newGraph

    def description(self) -> str:
        return Version3Description


DESERIALIZER_MAP = {1: DeserializerV1, 2: DeserializerV2, 3: DeserializerV3}
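

# A minimal sketch (illustrative only; not part of this module's API) of how
# a save-file reader might use DESERIALIZER_MAP, assuming `saveVersion` has
# already been parsed from the file and that `preambleSize` and `sizeBytes`
# follow the layout given in the version descriptions above. A version with
# no entry in the map has no deserializer and is an error.
def _exampleSelectDeserializer(saveVersion: int, preambleSize: int, sizeBytes: bytes) -> DeserializerBase:
    try:
        deserializerClass = DESERIALIZER_MAP[saveVersion]
    except KeyError:
        raise ValueError(f"QuantumGraph save version {saveVersion} is not supported") from None
    # Each deserializer is a dataclass taking (preambleSize, sizeBytes);
    # __post_init__ unpacks sizeBytes according to the class's FMT_STRING().
    return deserializerClass(preambleSize, sizeBytes)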