# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
__all__ = (
    "expect_not_none",
    "SerializedEdge",
    "SerializedTaskInitNode",
    "SerializedTaskNode",
    "SerializedDatasetTypeNode",
    "SerializedTaskSubset",
    "SerializedPipelineGraph",
)

from collections.abc import Mapping
from typing import Any, TypeVar

import networkx
import pydantic
from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGroup, DimensionUniverse

from .. import automatic_connection_constants as acc
from ._dataset_types import DatasetTypeNode
from ._edges import Edge, ReadEdge, WriteEdge
from ._exceptions import PipelineGraphReadError
from ._nodes import NodeKey, NodeType
from ._pipeline_graph import PipelineGraph
from ._task_subsets import TaskSubset
from ._tasks import TaskImportMode, TaskInitNode, TaskNode
_U = TypeVar("_U")

_IO_VERSION_INFO = (0, 0, 1)
"""Version tuple embedded in saved PipelineGraphs.
"""


def expect_not_none(value: _U | None, msg: str) -> _U:
    """Check that a value is not `None` and return it.

    Parameters
    ----------
    value : `~typing.Any`
        Value to check.
    msg : `str`
        Error message for the case where ``value is None``.

    Returns
    -------
    value : `typing.Any`
        Value, guaranteed not to be `None`.

    Raises
    ------
    PipelineGraphReadError
        Raised with ``msg`` if ``value is None``.
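
    Examples
    --------
    A minimal illustration (the error message here is arbitrary):

    >>> expect_not_none(5, "value is missing")
    5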
    """
    if value is None:
        raise PipelineGraphReadError(msg)
    return value


class SerializedEdge(pydantic.BaseModel):
    """Struct used to represent a serialized `Edge` in a `PipelineGraph`.

    All `ReadEdge` and `WriteEdge` state not included here is instead
    effectively serialized by the context in which a `SerializedEdge` appears
    (e.g. the keys of the nested dictionaries in which it serves as the value
    type).
    """

    dataset_type_name: str
    """Full dataset type name (including component)."""

    storage_class: str
    """Name of the storage class."""

    raw_dimensions: list[str]
    """Raw dimensions of the dataset type from the task connections."""

    is_calibration: bool = False
    """Whether this dataset type can be included in
    `~lsst.daf.butler.CollectionType.CALIBRATION` collections."""

    defer_query_constraint: bool = False
    """If `True`, by default do not include this dataset type's existence as a
    constraint on the initial data ID query in QuantumGraph generation."""

    @classmethod
    def serialize(cls, target: Edge) -> SerializedEdge:
        """Transform an `Edge` to a `SerializedEdge`.

        Parameters
        ----------
        target : `Edge`
            The object to serialize.

        Returns
        -------
        `SerializedEdge`
            Model transformed into something that can be serialized.
        """
        return SerializedEdge.model_construct(
            storage_class=target.storage_class_name,
            dataset_type_name=target.dataset_type_name,
            raw_dimensions=sorted(target.raw_dimensions),
            is_calibration=target.is_calibration,
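            # Only read edges carry a defer_query_constraint attribute (see
            # deserialize_read_edge below); fall back to False for write
            # edges.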
            defer_query_constraint=getattr(target, "defer_query_constraint", False),
        )

    def deserialize_read_edge(
        self,
        task_key: NodeKey,
        connection_name: str,
        dataset_type_keys: Mapping[str, NodeKey],
        is_prerequisite: bool = False,
    ) -> ReadEdge:
        """Transform a `SerializedEdge` to a `ReadEdge`.

        Parameters
        ----------
        task_key : `NodeKey`
            Key for the task node this edge is connected to.
        connection_name : `str`
            Internal name for the connection as seen by the task.
        dataset_type_keys : `~collections.abc.Mapping` [`str`, `NodeKey`]
            Mapping of dataset type name to node key.
        is_prerequisite : `bool`, optional
            Whether this dataset must be present in the data repository prior
            to `QuantumGraph` generation.

        Returns
        -------
        `ReadEdge`
            Deserialized object.
        """
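        # Component dataset types (e.g. "calexp.wcs") are keyed by their
        # parent dataset type's node; the component is carried on the edge.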
        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(self.dataset_type_name)
        return ReadEdge(
            dataset_type_key=dataset_type_keys[parent_dataset_type_name],
            task_key=task_key,
            storage_class_name=self.storage_class,
            is_prerequisite=is_prerequisite,
            component=component,
            connection_name=connection_name,
            is_calibration=self.is_calibration,
            defer_query_constraint=self.defer_query_constraint,
            raw_dimensions=frozenset(self.raw_dimensions),
        )

    def deserialize_write_edge(
        self,
        task_key: NodeKey,
        connection_name: str,
        dataset_type_keys: Mapping[str, NodeKey],
    ) -> WriteEdge:
        """Transform a `SerializedEdge` to a `WriteEdge`.

        Parameters
        ----------
        task_key : `NodeKey`
            Key for the task node this edge is connected to.
        connection_name : `str`
            Internal name for the connection as seen by the task.
        dataset_type_keys : `~collections.abc.Mapping` [`str`, `NodeKey`]
            Mapping of dataset type name to node key.

        Returns
        -------
        `WriteEdge`
            Deserialized object.
        """
        return WriteEdge(
            task_key=task_key,
            dataset_type_key=dataset_type_keys[self.dataset_type_name],
            storage_class_name=self.storage_class,
            connection_name=connection_name,
            is_calibration=self.is_calibration,
            raw_dimensions=frozenset(self.raw_dimensions),
        )


class SerializedTaskInitNode(pydantic.BaseModel):
    """Struct used to represent a serialized `TaskInitNode` in a
    `PipelineGraph`.

    The task label is serialized by the context in which a
    `SerializedTaskInitNode` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type), and the task class name and config
    string are saved with the corresponding `SerializedTaskNode`.
    """

    inputs: dict[str, SerializedEdge]
    """Mapping of serialized init-input edges, keyed by connection name."""

    outputs: dict[str, SerializedEdge]
    """Mapping of serialized init-output edges, keyed by connection name."""

    config_output: SerializedEdge
    """The serialized config init-output edge."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode:
        """Transform a `TaskInitNode` to a `SerializedTaskInitNode`.

        Parameters
        ----------
        target : `TaskInitNode`
            Object to be serialized.

        Returns
        -------
        `SerializedTaskInitNode`
            Model that can be serialized.
        """
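        # Sort the connection mappings so the serialized form is
        # deterministic regardless of connection declaration order.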
        return cls.model_construct(
            inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.inputs.items())
            },
            outputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.outputs.items())
            },
            config_output=SerializedEdge.serialize(target.config_output),
        )

    def deserialize(
        self,
        key: NodeKey,
        task_class_name: str,
        config_str: str,
        dataset_type_keys: Mapping[str, NodeKey],
    ) -> TaskInitNode:
        """Transform a `SerializedTaskInitNode` to a `TaskInitNode`.

        Parameters
        ----------
        key : `NodeKey`
            Key that identifies this node in internal and exported networkx
            graphs.
        task_class_name : `str`
            Fully-qualified name of the task class.
        config_str : `str`
            Configuration for the task as a string of override statements.
        dataset_type_keys : `~collections.abc.Mapping` [`str`, `NodeKey`]
            Mapping of dataset type name to node key.

        Returns
        -------
        `TaskInitNode`
            Deserialized object.
        """
        return TaskInitNode(
            key,
            inputs={
                connection_name: serialized_edge.deserialize_read_edge(
                    key, connection_name, dataset_type_keys
                )
                for connection_name, serialized_edge in self.inputs.items()
            },
            outputs={
                connection_name: serialized_edge.deserialize_write_edge(
                    key, connection_name, dataset_type_keys
                )
                for connection_name, serialized_edge in self.outputs.items()
            },
            config_output=self.config_output.deserialize_write_edge(
                key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, dataset_type_keys
            ),
            task_class_name=task_class_name,
            config_str=config_str,
        )


class SerializedTaskNode(pydantic.BaseModel):
    """Struct used to represent a serialized `TaskNode` in a `PipelineGraph`.

    The task label is serialized by the context in which a
    `SerializedTaskNode` appears (e.g. the keys of the nested dictionary in
    which it serves as the value type).
    """

    task_class: str
    """Fully-qualified name of the task class."""

    init: SerializedTaskInitNode
    """Serialized task initialization node."""

    config_str: str
    """Configuration for the task as a string of override statements."""

    prerequisite_inputs: dict[str, SerializedEdge]
    """Mapping of serialized prerequisite input edges, keyed by connection
    name.
    """

    inputs: dict[str, SerializedEdge]
    """Mapping of serialized input edges, keyed by connection name."""

    outputs: dict[str, SerializedEdge]
    """Mapping of serialized output edges, keyed by connection name."""

    metadata_output: SerializedEdge
    """The serialized metadata output edge."""

    dimensions: list[str]
    """The task's dimensions, if they were resolved."""

    log_output: SerializedEdge | None = None
    """The serialized log output edge."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: TaskNode) -> SerializedTaskNode:
        """Transform a `TaskNode` to a `SerializedTaskNode`.

        Parameters
        ----------
        target : `TaskNode`
            Object to be serialized.

        Returns
        -------
        `SerializedTaskNode`
            Object that can be serialized.
        """
        return cls.model_construct(
            task_class=target.task_class_name,
            init=SerializedTaskInitNode.serialize(target.init),
            config_str=target.get_config_str(),
            dimensions=list(target.raw_dimensions),
            prerequisite_inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.prerequisite_inputs.items())
            },
            inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.inputs.items())
            },
            outputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.outputs.items())
            },
            metadata_output=SerializedEdge.serialize(target.metadata_output),
            log_output=(
                SerializedEdge.serialize(target.log_output) if target.log_output is not None else None
            ),
        )

    def deserialize(
        self,
        key: NodeKey,
        init_key: NodeKey,
        dataset_type_keys: Mapping[str, NodeKey],
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Transform a `SerializedTaskNode` to a `TaskNode`.

        Parameters
        ----------
        key : `NodeKey`
            Key that identifies this node in internal and exported networkx
            graphs.
        init_key : `NodeKey`
            Key that identifies the init node in internal and exported
            networkx graphs.
        dataset_type_keys : `~collections.abc.Mapping` [`str`, `NodeKey`]
            Mapping of dataset type name to node key.
        universe : `~lsst.daf.butler.DimensionUniverse` or `None`
            The dimension universe.

        Returns
        -------
        `TaskNode`
            Deserialized object.
        """
        init = self.init.deserialize(
            init_key,
            task_class_name=self.task_class,
            config_str=expect_not_none(
                self.config_str, f"No serialized config file for task with label {key.name!r}."
            ),
            dataset_type_keys=dataset_type_keys,
        )
        inputs = {
            connection_name: serialized_edge.deserialize_read_edge(key, connection_name, dataset_type_keys)
            for connection_name, serialized_edge in self.inputs.items()
        }
        prerequisite_inputs = {
            connection_name: serialized_edge.deserialize_read_edge(
                key, connection_name, dataset_type_keys, is_prerequisite=True
            )
            for connection_name, serialized_edge in self.prerequisite_inputs.items()
        }
        outputs = {
            connection_name: serialized_edge.deserialize_write_edge(key, connection_name, dataset_type_keys)
            for connection_name, serialized_edge in self.outputs.items()
        }
        if (serialized_log_output := self.log_output) is not None:
            log_output = serialized_log_output.deserialize_write_edge(
                key, acc.LOG_OUTPUT_CONNECTION_NAME, dataset_type_keys
            )
        else:
            log_output = None
        metadata_output = self.metadata_output.deserialize_write_edge(
            key, acc.METADATA_OUTPUT_CONNECTION_NAME, dataset_type_keys
        )
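        # If no universe was saved with the graph, keep the dimensions as an
        # unresolved frozenset of names rather than a resolved DimensionGroup.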
        dimensions: frozenset[str] | DimensionGroup
        if universe is not None:
            dimensions = universe.conform(self.dimensions)
        else:
            dimensions = frozenset(self.dimensions)
        return TaskNode(
            key=key,
            init=init,
            inputs=inputs,
            prerequisite_inputs=prerequisite_inputs,
            outputs=outputs,
            log_output=log_output,
            metadata_output=metadata_output,
            dimensions=dimensions,
        )


class SerializedDatasetTypeNode(pydantic.BaseModel):
    """Struct used to represent a serialized `DatasetTypeNode` in a
    `PipelineGraph`.

    Unresolved dataset types are serialized as instances with at most the
    `index` attribute set, and are typically converted to JSON with pydantic's
    ``exclude_defaults=True`` option to keep this compact.

    The dataset type name is serialized by the context in which a
    `SerializedDatasetTypeNode` appears (e.g. the keys of the nested
    dictionary in which it serves as the value type).
    """

    dimensions: list[str] | None = None
    """Dimensions of the dataset type."""

    storage_class: str | None = None
    """Name of the storage class."""

    is_calibration: bool = False
    """Whether this dataset type is a calibration."""

    is_initial_query_constraint: bool = False
    """Whether this dataset type should be a query constraint during
    `QuantumGraph` generation."""

    is_prerequisite: bool = False
    """Whether datasets of this dataset type must exist in the input
    collection before `QuantumGraph` generation."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode:
        """Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`.

        Parameters
        ----------
        target : `DatasetTypeNode` or `None`
            Object to serialize.

        Returns
        -------
        `SerializedDatasetTypeNode`
            Object in serializable form.
        """
        if target is None:
            return cls.model_construct()
        return cls.model_construct(
            dimensions=list(target.dataset_type.dimensions.names),
            storage_class=target.dataset_type.storageClass_name,
            is_calibration=target.dataset_type.isCalibration(),
            is_initial_query_constraint=target.is_initial_query_constraint,
            is_prerequisite=target.is_prerequisite,
        )

    def deserialize(
        self, key: NodeKey, xgraph: networkx.MultiDiGraph, universe: DimensionUniverse | None
    ) -> DatasetTypeNode | None:
        """Transform a `SerializedDatasetTypeNode` to a `DatasetTypeNode`.

        Parameters
        ----------
        key : `NodeKey`
            Key that identifies this node in internal and exported networkx
            graphs.
        xgraph : `networkx.MultiDiGraph`
            Networkx graph of the full pipeline, already populated with task
            nodes and edges; used to look up this dataset type's producing
            and consuming edges.
        universe : `~lsst.daf.butler.DimensionUniverse` or `None`
            The dimension universe.

        Returns
        -------
        `DatasetTypeNode`
            Deserialized object.
        """
        if self.dimensions is not None:
            dataset_type = DatasetType(
                key.name,
                expect_not_none(
                    self.dimensions,
                    f"Serialized dataset type {key.name!r} has no dimensions.",
                ),
                storageClass=expect_not_none(
                    self.storage_class,
                    f"Serialized dataset type {key.name!r} has no storage class.",
                ),
                isCalibration=self.is_calibration,
                universe=expect_not_none(
                    universe,
                    f"Serialized dataset type {key.name!r} has dimensions, "
                    "but no dimension universe was stored.",
                ),
            )
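            # Write edges were added to the graph as (task -> dataset type)
            # edges, so this node's in-edges are its producers; read edges
            # run (dataset type -> task), so its out-edges are its consumers.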
            producer: str | None = None
            producing_edge: WriteEdge | None = None
            for _, _, producing_edge in xgraph.in_edges(key, data="instance"):
                assert producing_edge is not None, "Should only be None if we never loop."
                if producer is not None:
                    raise PipelineGraphReadError(
                        f"Serialized dataset type {key.name!r} is produced by both "
                        f"{producing_edge.task_label!r} and {producer!r} in resolved graph."
                    )
                producer = producing_edge.task_label
            consuming_edges = tuple(
                consuming_edge for _, _, consuming_edge in xgraph.out_edges(key, data="instance")
            )
            return DatasetTypeNode(
                dataset_type=dataset_type,
                is_prerequisite=self.is_prerequisite,
                is_initial_query_constraint=self.is_initial_query_constraint,
                producing_edge=producing_edge,
                consuming_edges=consuming_edges,
            )
        return None


class SerializedTaskSubset(pydantic.BaseModel):
    """Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`.

    The subset label is serialized by the context in which a
    `SerializedTaskSubset` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type).
    """

    description: str
    """Description of the subset."""

    tasks: list[str]
    """Labels of tasks in the subset, sorted lexicographically for
    determinism.
    """

    @classmethod
    def serialize(cls, target: TaskSubset) -> SerializedTaskSubset:
        """Transform a `TaskSubset` into a `SerializedTaskSubset`.

        Parameters
        ----------
        target : `TaskSubset`
            Object to serialize.

        Returns
        -------
        `SerializedTaskSubset`
            Object in serializable form.
        """
        return cls.model_construct(description=target._description, tasks=sorted(target))

    def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) -> TaskSubset:
        """Transform a `SerializedTaskSubset` into a `TaskSubset`.

        Parameters
        ----------
        label : `str`
            Subset label.
        xgraph : `networkx.MultiDiGraph`
            Networkx graph of the full pipeline that the subset will view.

        Returns
        -------
        `TaskSubset`
            Deserialized object.
        """
        members = set(self.tasks)
        return TaskSubset(xgraph, label, members, self.description)


class SerializedPipelineGraph(pydantic.BaseModel):
    """Struct used to represent a serialized `PipelineGraph`."""

    version: str = ".".join(str(v) for v in _IO_VERSION_INFO)
    """Serialization version."""

    description: str
    """Human-readable description of the pipeline."""

    tasks: dict[str, SerializedTaskNode] = pydantic.Field(default_factory=dict)
    """Mapping of serialized tasks, keyed by label."""

    dataset_types: dict[str, SerializedDatasetTypeNode] = pydantic.Field(default_factory=dict)
    """Mapping of serialized dataset types, keyed by parent dataset type name.
    """

    task_subsets: dict[str, SerializedTaskSubset] = pydantic.Field(default_factory=dict)
    """Mapping of task subsets, keyed by subset label."""

    dimensions: dict[str, Any] | None = None
    """Dimension universe configuration."""

    data_id: dict[str, Any] = pydantic.Field(default_factory=dict)
    """Data ID that constrains all quanta generated from this pipeline."""

    @classmethod
    def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:
        """Transform a `PipelineGraph` into a `SerializedPipelineGraph`.

        Parameters
        ----------
        target : `PipelineGraph`
            Object to serialize.

        Returns
        -------
        `SerializedPipelineGraph`
            Object in serializable form.
        """
        result = SerializedPipelineGraph.model_construct(
            description=target.description,
            tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()},
            dataset_types={
                name: SerializedDatasetTypeNode.serialize(target.dataset_types.get_if_resolved(name))
                for name in target.dataset_types
            },
            task_subsets={
                label: SerializedTaskSubset.serialize(subset) for label, subset in target.task_subsets.items()
            },
            dimensions=target.universe.dimensionConfig.toDict() if target.universe is not None else None,
            data_id=target._raw_data_id,
        )
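        # Record each node's position in the deterministic sort (if one has
        # been done) so deserialization can restore the ordering.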
        if target._sorted_keys:
            for index, node_key in enumerate(target._sorted_keys):
                match node_key.node_type:
                    case NodeType.TASK:
                        result.tasks[node_key.name].index = index
                    case NodeType.DATASET_TYPE:
                        result.dataset_types[node_key.name].index = index
                    case NodeType.TASK_INIT:
                        result.tasks[node_key.name].init.index = index
        return result

    def deserialize(
        self,
        import_mode: TaskImportMode,
    ) -> PipelineGraph:
        """Transform a `SerializedPipelineGraph` into a `PipelineGraph`.

        Parameters
        ----------
        import_mode : `TaskImportMode`
            Import mode.

        Returns
        -------
        `PipelineGraph`
            Deserialized object.
        """
        universe: DimensionUniverse | None = None
        if self.dimensions is not None:
            universe = DimensionUniverse(
                config=DimensionConfig(
                    expect_not_none(
                        self.dimensions,
                        "Serialized pipeline graph has not been resolved; "
                        "load it as a MutablePipelineGraph instead.",
                    )
                )
            )
        xgraph = networkx.MultiDiGraph()
        sort_index_map: dict[int, NodeKey] = {}
        # Save the dataset type keys after the first time we make them - these
        # may be tiny objects, but it's still nice to have only one copy of
        # each value floating around the graph.
        dataset_type_keys: dict[str, NodeKey] = {}
        for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
            dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name)
            # We intentionally don't attach a DatasetTypeNode instance here
            # yet, since we need edges to do that and those are saved with
            # the tasks.
            xgraph.add_node(dataset_type_key, bipartite=NodeType.DATASET_TYPE.bipartite)
            if serialized_dataset_type.index is not None:
                sort_index_map[serialized_dataset_type.index] = dataset_type_key
            dataset_type_keys[dataset_type_name] = dataset_type_key
        for task_label, serialized_task in self.tasks.items():
            task_key = NodeKey(NodeType.TASK, task_label)
            task_init_key = NodeKey(NodeType.TASK_INIT, task_label)
            task_node = serialized_task.deserialize(task_key, task_init_key, dataset_type_keys, universe)
            if serialized_task.index is not None:
                sort_index_map[serialized_task.index] = task_key
            if serialized_task.init.index is not None:
                sort_index_map[serialized_task.init.index] = task_init_key
            xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite)
            xgraph.add_node(task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite)
            xgraph.add_edge(task_init_key, task_key, Edge.INIT_TO_TASK_NAME, instance=None)
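            # Edge direction encodes data flow: read edges point from the
            # dataset type node to the task that consumes it, and write edges
            # point from the task to the dataset type it produces.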
            for read_edge in task_node.init.iter_all_inputs():
                xgraph.add_edge(
                    read_edge.dataset_type_key,
                    read_edge.task_key,
                    read_edge.connection_name,
                    instance=read_edge,
                )
            for write_edge in task_node.init.iter_all_outputs():
                xgraph.add_edge(
                    write_edge.task_key,
                    write_edge.dataset_type_key,
                    write_edge.connection_name,
                    instance=write_edge,
                )
            for read_edge in task_node.iter_all_inputs():
                xgraph.add_edge(
                    read_edge.dataset_type_key,
                    read_edge.task_key,
                    read_edge.connection_name,
                    instance=read_edge,
                )
            for write_edge in task_node.iter_all_outputs():
                xgraph.add_edge(
                    write_edge.task_key,
                    write_edge.dataset_type_key,
                    write_edge.connection_name,
                    instance=write_edge,
                )
        # Iterate over dataset types again to add instances.
        for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
            dataset_type_key = dataset_type_keys[dataset_type_name]
            xgraph.nodes[dataset_type_key]["instance"] = serialized_dataset_type.deserialize(
                dataset_type_key, xgraph, universe
            )
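        # Bypass PipelineGraph.__init__ and populate the new instance directly
        # from the deserialized components.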
        result = PipelineGraph.__new__(PipelineGraph)
        result._init_from_args(
            xgraph,
            sorted_keys=[sort_index_map[i] for i in range(len(xgraph))] if sort_index_map else None,
            task_subsets={
                subset_label: serialized_subset.deserialize_task_subset(subset_label, xgraph)
                for subset_label, serialized_subset in self.task_subsets.items()
            },
            description=self.description,
            universe=universe,
            data_id=self.data_id,
        )
        result._import_and_configure(import_mode)
        return result
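

# Example round trip (a sketch, not part of the module API; assumes a resolved
# `PipelineGraph` instance named `graph` and that `TaskImportMode` has a
# `DO_NOT_IMPORT` member):
#
#     serialized = SerializedPipelineGraph.serialize(graph)
#     json_str = serialized.model_dump_json(exclude_defaults=True)
#     restored = SerializedPipelineGraph.model_validate_json(json_str).deserialize(
#         TaskImportMode.DO_NOT_IMPORT
#     )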