Coverage for python / lsst / pipe / base / pipeline_graph / io.py: 45%
212 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:59 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:59 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
# Public API of this module: the pydantic models used to (de)serialize a
# PipelineGraph and its components.
__all__ = (
    "SerializedDatasetTypeNode",
    "SerializedEdge",
    "SerializedPipelineGraph",
    "SerializedTaskInitNode",
    "SerializedTaskNode",
    "SerializedTaskSubset",
)
38from collections.abc import Mapping
39from typing import Any
41import networkx
42import pydantic
44from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGroup, DimensionUniverse
46from .. import automatic_connection_constants as acc
47from ._dataset_types import DatasetTypeNode
48from ._edges import Edge, ReadEdge, WriteEdge
49from ._exceptions import PipelineGraphReadError
50from ._nodes import NodeKey, NodeType
51from ._pipeline_graph import PipelineGraph
52from ._task_subsets import StepDefinitions, TaskSubset
53from ._tasks import TaskImportMode, TaskInitNode, TaskNode
# (major, minor, patch) version of the serialization format, embedded in every
# saved graph (see SerializedPipelineGraph.version) so readers can detect
# incompatible files.
_IO_VERSION_INFO = (0, 0, 1)
"""Version tuple embedded in saved PipelineGraphs.
"""
60def _expect_not_none[U](value: U | None, msg: str) -> U:
61 """Check that a value is not `None` and return it.
63 Parameters
64 ----------
65 value : `~typing.Any`
66 Value to check.
67 msg : `str`
68 Error message for the case where ``value is None``.
70 Returns
71 -------
72 value : `typing.Any`
73 Value, guaranteed not to be `None`.
75 Raises
76 ------
77 PipelineGraphReadError
78 Raised with ``msg`` if ``value is None``.
79 """
80 if value is None:
81 raise PipelineGraphReadError(msg)
82 return value
class SerializedEdge(pydantic.BaseModel):
    """Struct used to represent a serialized `Edge` in a `PipelineGraph`.

    All `ReadEdge` and `WriteEdge` state not included here is instead
    effectively serialized by the context in which a `SerializedEdge` appears
    (e.g. the keys of the nested dictionaries in which it serves as the value
    type).
    """

    dataset_type_name: str
    """Full dataset type name (including component)."""

    storage_class: str
    """Name of the storage class."""

    raw_dimensions: list[str]
    """Raw dimensions of the dataset type from the task connections."""

    is_calibration: bool = False
    """Whether this dataset type can be included in
    `~lsst.daf.butler.CollectionType.CALIBRATION` collections."""

    defer_query_constraint: bool = False
    """If `True`, by default do not include this dataset type's existence as a
    constraint on the initial data ID query in QuantumGraph generation."""

    @classmethod
    def serialize(cls, target: Edge) -> SerializedEdge:
        """Transform an `Edge` to a `SerializedEdge`.

        Parameters
        ----------
        target : `Edge`
            The object to serialize.

        Returns
        -------
        `SerializedEdge`
            Model transformed into something that can be serialized.
        """
        # Not every edge type carries defer_query_constraint; fall back to
        # the field default when it is absent.
        deferred = getattr(target, "defer_query_constraint", False)
        return cls.model_construct(
            dataset_type_name=target.dataset_type_name,
            storage_class=target.storage_class_name,
            raw_dimensions=sorted(target.raw_dimensions),
            is_calibration=target.is_calibration,
            defer_query_constraint=deferred,
        )

    def deserialize_read_edge(
        self,
        task_key: NodeKey,
        connection_name: str,
        dataset_type_keys: Mapping[str, NodeKey],
        is_prerequisite: bool = False,
    ) -> ReadEdge:
        """Transform a `SerializedEdge` to a `ReadEdge`.

        Parameters
        ----------
        task_key : `NodeKey`
            Key for the task node this edge is connected to.
        connection_name : `str`
            Internal name for the connection as seen by the task.
        dataset_type_keys : `~collections.abc.Mapping` [`str`, `NodeKey`]
            Mapping of dataset type name to node key.
        is_prerequisite : `bool`, optional
            Whether this dataset must be present in the data repository prior
            to `QuantumGraph` generation.

        Returns
        -------
        `ReadEdge`
            Deserialized object.
        """
        # Read edges may refer to a component; the graph node is always keyed
        # by the parent dataset type name.
        parent_name, component = DatasetType.splitDatasetTypeName(self.dataset_type_name)
        return ReadEdge(
            dataset_type_key=dataset_type_keys[parent_name],
            task_key=task_key,
            storage_class_name=self.storage_class,
            is_prerequisite=is_prerequisite,
            component=component,
            connection_name=connection_name,
            is_calibration=self.is_calibration,
            defer_query_constraint=self.defer_query_constraint,
            raw_dimensions=frozenset(self.raw_dimensions),
        )

    def deserialize_write_edge(
        self,
        task_key: NodeKey,
        connection_name: str,
        dataset_type_keys: Mapping[str, NodeKey],
    ) -> WriteEdge:
        """Transform a `SerializedEdge` to a `WriteEdge`.

        Parameters
        ----------
        task_key : `NodeKey`
            Key for the task node this edge is connected to.
        connection_name : `str`
            Internal name for the connection as seen by the task.
        dataset_type_keys : `~collections.abc.Mapping` [`str`, `NodeKey`]
            Mapping of dataset type name to node key.

        Returns
        -------
        `WriteEdge`
            Deserialized object.
        """
        # Output dataset types are never components, so the serialized name
        # can be used directly as the node-key lookup.
        raw_dims = frozenset(self.raw_dimensions)
        return WriteEdge(
            task_key=task_key,
            dataset_type_key=dataset_type_keys[self.dataset_type_name],
            storage_class_name=self.storage_class,
            connection_name=connection_name,
            is_calibration=self.is_calibration,
            raw_dimensions=raw_dims,
        )
class SerializedTaskInitNode(pydantic.BaseModel):
    """Struct used to represent a serialized `TaskInitNode` in a
    `PipelineGraph`.

    The task label is serialized by the context in which a
    `SerializedTaskInitNode` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type), and the task class name and config
    string are saved with the corresponding `SerializedTaskNode`.
    """

    inputs: dict[str, SerializedEdge]
    """Mapping of serialized init-input edges, keyed by connection name."""

    outputs: dict[str, SerializedEdge]
    """Mapping of serialized init-output edges, keyed by connection name."""

    config_output: SerializedEdge
    """The serialized config init-output edge."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode:
        """Transform a `TaskInitNode` to a `SerializedTaskInitNode`.

        Parameters
        ----------
        target : `TaskInitNode`
            Object to be serialized.

        Returns
        -------
        `SerializedTaskInitNode`
            Model that can be serialized.
        """
        # Sort by connection name so serialization is deterministic.
        serialized_inputs = {
            name: SerializedEdge.serialize(edge) for name, edge in sorted(target.inputs.items())
        }
        serialized_outputs = {
            name: SerializedEdge.serialize(edge) for name, edge in sorted(target.outputs.items())
        }
        return cls.model_construct(
            inputs=serialized_inputs,
            outputs=serialized_outputs,
            config_output=SerializedEdge.serialize(target.config_output),
        )

    def deserialize(
        self,
        key: NodeKey,
        task_class_name: str,
        config_str: str,
        dataset_type_keys: Mapping[str, NodeKey],
    ) -> TaskInitNode:
        """Transform a `SerializedTaskInitNode` to a `TaskInitNode`.

        Parameters
        ----------
        key : `NodeKey`
            Key that identifies this node in internal and exported networkx
            graphs.
        task_class_name : `str`, optional
            Fully-qualified name of the task class.  Must be provided if
            ``imported_data`` is not.
        config_str : `str`, optional
            Configuration for the task as a string of override statements.
        dataset_type_keys : `~collections.abc.Mapping` [`str`, `NodeKey`]
            Mapping of dataset type name to node key.

        Returns
        -------
        `TaskInitNode`
            Deserialized object.
        """
        read_edges = {
            name: serialized.deserialize_read_edge(key, name, dataset_type_keys)
            for name, serialized in self.inputs.items()
        }
        write_edges = {
            name: serialized.deserialize_write_edge(key, name, dataset_type_keys)
            for name, serialized in self.outputs.items()
        }
        # The config output's connection name is not serialized because it is
        # always the same automatic connection.
        config_edge = self.config_output.deserialize_write_edge(
            key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, dataset_type_keys
        )
        return TaskInitNode(
            key,
            inputs=read_edges,
            outputs=write_edges,
            config_output=config_edge,
            task_class_name=task_class_name,
            config_str=config_str,
        )
class SerializedTaskNode(pydantic.BaseModel):
    """Struct used to represent a serialized `TaskNode` in a `PipelineGraph`.

    The task label is serialized by the context in which a
    `SerializedTaskNode` appears (e.g. the keys of the nested dictionary in
    which it serves as the value type).
    """

    task_class: str
    """Fully-qualified name of the task class."""

    init: SerializedTaskInitNode
    """Serialized task initialization node."""

    config_str: str
    """Configuration for the task as a string of override statements."""

    prerequisite_inputs: dict[str, SerializedEdge]
    """Mapping of serialized prerequisite input edges, keyed by connection
    name.
    """

    inputs: dict[str, SerializedEdge]
    """Mapping of serialized input edges, keyed by connection name."""

    outputs: dict[str, SerializedEdge]
    """Mapping of serialized output edges, keyed by connection name."""

    metadata_output: SerializedEdge
    """The serialized metadata output edge."""

    dimensions: list[str]
    """The task's dimensions, if they were resolved."""

    log_output: SerializedEdge | None = None
    """The serialized log output edge."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: TaskNode) -> SerializedTaskNode:
        """Transform a `TaskNode` to a `SerializedTaskNode`.

        Parameters
        ----------
        target : `TaskNode`
            Object to be serialized.

        Returns
        -------
        `SerializedTaskNode`
            Object that can be serialized.
        """
        # Dimensions and edge mappings are sorted so serialization is
        # deterministic.
        dimensions = list(target.raw_dimensions)
        dimensions.sort()
        return cls.model_construct(
            task_class=target.task_class_name,
            init=SerializedTaskInitNode.serialize(target.init),
            config_str=target.get_config_str(),
            dimensions=dimensions,
            prerequisite_inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.prerequisite_inputs.items())
            },
            inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.inputs.items())
            },
            outputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.outputs.items())
            },
            metadata_output=SerializedEdge.serialize(target.metadata_output),
            log_output=(
                SerializedEdge.serialize(target.log_output) if target.log_output is not None else None
            ),
        )

    def deserialize(
        self,
        key: NodeKey,
        init_key: NodeKey,
        dataset_type_keys: Mapping[str, NodeKey],
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Transform a `SerializedTaskNode` to a `TaskNode`.

        Parameters
        ----------
        key : `NodeKey`
            Key that identifies this node in internal and exported networkx
            graphs.
        init_key : `NodeKey`
            Key that identifies the init node in internal and exported networkx
            graphs.
        dataset_type_keys : `~collections.abc.Mapping` [`str`, `NodeKey`]
            Mapping of dataset type name to node key.
        universe : `~lsst.daf.butler.DimensionUniverse` or `None`
            The dimension universe; if `None`, the task's dimensions are left
            as an unstandardized `frozenset` of names.

        Returns
        -------
        `TaskNode`
            Deserialized object.
        """
        init = self.init.deserialize(
            init_key,
            task_class_name=self.task_class,
            config_str=_expect_not_none(
                self.config_str, f"No serialized config file for task with label {key.name!r}."
            ),
            dataset_type_keys=dataset_type_keys,
        )
        inputs = {
            connection_name: serialized_edge.deserialize_read_edge(key, connection_name, dataset_type_keys)
            for connection_name, serialized_edge in self.inputs.items()
        }
        prerequisite_inputs = {
            connection_name: serialized_edge.deserialize_read_edge(
                key, connection_name, dataset_type_keys, is_prerequisite=True
            )
            for connection_name, serialized_edge in self.prerequisite_inputs.items()
        }
        outputs = {
            connection_name: serialized_edge.deserialize_write_edge(key, connection_name, dataset_type_keys)
            for connection_name, serialized_edge in self.outputs.items()
        }
        # The log and metadata connection names are not serialized because
        # they are always the same automatic connections.
        if (serialized_log_output := self.log_output) is not None:
            log_output = serialized_log_output.deserialize_write_edge(
                key, acc.LOG_OUTPUT_CONNECTION_NAME, dataset_type_keys
            )
        else:
            log_output = None
        metadata_output = self.metadata_output.deserialize_write_edge(
            key, acc.METADATA_OUTPUT_CONNECTION_NAME, dataset_type_keys
        )
        dimensions: frozenset[str] | DimensionGroup
        if universe is not None:
            dimensions = universe.conform(self.dimensions)
        else:
            dimensions = frozenset(self.dimensions)
        return TaskNode(
            key=key,
            init=init,
            inputs=inputs,
            prerequisite_inputs=prerequisite_inputs,
            outputs=outputs,
            log_output=log_output,
            metadata_output=metadata_output,
            dimensions=dimensions,
        )
class SerializedDatasetTypeNode(pydantic.BaseModel):
    """Struct used to represent a serialized `DatasetTypeNode` in a
    `PipelineGraph`.

    Unresolved dataset types are serialized as instances with at most the
    `index` attribute set, and are typically converted to JSON with pydantic's
    ``exclude_defaults=True`` option to keep this compact.

    The dataset type name is serialized by the context in which a
    `SerializedDatasetTypeNode` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type).
    """

    dimensions: list[str] | None = None
    """Dimensions of the dataset type."""

    storage_class: str | None = None
    """Name of the storage class."""

    is_calibration: bool = False
    """Whether this dataset type is a calibration."""

    is_initial_query_constraint: bool = False
    """Whether this dataset type should be a query constraint during
    `QuantumGraph` generation."""

    is_prerequisite: bool = False
    """Whether datasets of this dataset type must exist in the input collection
    before `QuantumGraph` generation."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode:
        """Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`.

        Parameters
        ----------
        target : `DatasetTypeNode` or `None`
            Object to serialize.

        Returns
        -------
        `SerializedDatasetTypeNode`
            Object in serializable form.
        """
        if target is None:
            # Unresolved dataset type: serialize with all defaults.
            return cls.model_construct()
        return cls.model_construct(
            dimensions=list(target.dataset_type.dimensions.names),
            storage_class=target.dataset_type.storageClass_name,
            is_calibration=target.dataset_type.isCalibration(),
            is_initial_query_constraint=target.is_initial_query_constraint,
            is_prerequisite=target.is_prerequisite,
        )

    def deserialize(
        self, key: NodeKey, xgraph: networkx.MultiDiGraph, universe: DimensionUniverse | None
    ) -> DatasetTypeNode | None:
        """Transform a `SerializedDatasetTypeNode` to a `DatasetTypeNode`.

        Parameters
        ----------
        key : `NodeKey`
            Key that identifies this node in internal and exported networkx
            graphs.
        xgraph : `networkx.MultiDiGraph`
            Under-construction networkx graph that backs the pipeline graph;
            its edges (which must already have been added) are used to find
            this dataset type's producing and consuming edges.
        universe : `~lsst.daf.butler.DimensionUniverse` or `None`
            The dimension universe.

        Returns
        -------
        `DatasetTypeNode` or `None`
            Deserialized object, or `None` if this dataset type was not
            resolved when it was serialized.
        """
        if self.dimensions is not None:
            dataset_type = DatasetType(
                key.name,
                _expect_not_none(
                    self.dimensions,
                    f"Serialized dataset type {key.name!r} has no dimensions.",
                ),
                storageClass=_expect_not_none(
                    self.storage_class,
                    f"Serialized dataset type {key.name!r} has no storage class.",
                ),
                isCalibration=self.is_calibration,
                universe=_expect_not_none(
                    universe,
                    f"Serialized dataset type {key.name!r} has dimensions, "
                    "but no dimension universe was stored.",
                ),
            )
            producer: str | None = None
            producing_edge: WriteEdge | None = None
            # In-edges point from tasks into this dataset type node, i.e. the
            # write edges that produce it; a resolved graph allows at most one.
            for _, _, producing_edge in xgraph.in_edges(key, data="instance"):
                assert producing_edge is not None, "Should only be None if we never loop."
                if producer is not None:
                    raise PipelineGraphReadError(
                        f"Serialized dataset type {key.name!r} is produced by both "
                        f"{producing_edge.task_label!r} and {producer!r} in resolved graph."
                    )
                producer = producing_edge.task_label
            # Consuming edges point from this dataset type node out to the
            # tasks that read it, so they are the *out*-edges (using in_edges
            # here would just duplicate the producing edges).
            consuming_edges = tuple(
                consuming_edge for _, _, consuming_edge in xgraph.out_edges(key, data="instance")
            )
            return DatasetTypeNode(
                dataset_type=dataset_type,
                is_prerequisite=self.is_prerequisite,
                is_initial_query_constraint=self.is_initial_query_constraint,
                producing_edge=producing_edge,
                consuming_edges=consuming_edges,
            )
        return None
class SerializedTaskSubset(pydantic.BaseModel):
    """Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`.

    The subset label is serialized by the context in which a
    `SerializedTaskSubset` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type).
    """

    description: str
    """Description of the subset."""

    tasks: list[str]
    """Labels of tasks in the subset, sorted lexicographically for
    determinism.
    """

    dimensions: list[str]
    """Dimensions that can be used to divide this step's quanta into
    independent groups.
    """

    @classmethod
    def serialize(cls, target: TaskSubset) -> SerializedTaskSubset:
        """Transform a `TaskSubset` into a `SerializedTaskSubset`.

        Parameters
        ----------
        target : `TaskSubset`
            Object to serialize.

        Returns
        -------
        `SerializedTaskSubset`
            Object in serializable form.
        """
        # Only subsets that are steps have sharding dimensions; everything
        # else falls back to an empty collection.
        sharding_dimensions = target._step_definitions._dimensions_by_label.get(target.label, ())
        return cls.model_construct(
            description=target._description,
            tasks=sorted(target),
            dimensions=sorted(sharding_dimensions),
        )

    def deserialize_task_subset(
        self, label: str, xgraph: networkx.MultiDiGraph, steps: StepDefinitions
    ) -> TaskSubset:
        """Transform a `SerializedTaskSubset` into a `TaskSubset`.

        Parameters
        ----------
        label : `str`
            Subset label.
        xgraph : `networkx.MultiDiGraph`
            The under-construction networkx graph that backs the pipeline
            graph.
        steps : `StepDefinitions`
            Step definitions for the pipeline graph.  Modified in-place to
            set sharding dimension if this is a step.
        """
        if label in steps:
            # This subset is a step; record its sharding dimensions in-place.
            steps._dimensions_by_label[label] = frozenset(self.dimensions)
        return TaskSubset(xgraph, label, set(self.tasks), self.description, steps)
654class SerializedPipelineGraph(pydantic.BaseModel):
655 """Struct used to represent a serialized `PipelineGraph`."""
657 version: str = ".".join(str(v) for v in _IO_VERSION_INFO)
658 """Serialization version."""
660 description: str
661 """Human-readable description of the pipeline."""
663 tasks: dict[str, SerializedTaskNode] = pydantic.Field(default_factory=dict)
664 """Mapping of serialized tasks, keyed by label."""
666 dataset_types: dict[str, SerializedDatasetTypeNode] = pydantic.Field(default_factory=dict)
667 """Mapping of serialized dataset types, keyed by parent dataset type name.
668 """
670 task_subsets: dict[str, SerializedTaskSubset] = pydantic.Field(default_factory=dict)
671 """Mapping of task subsets, keyed by subset label."""
673 dimensions: dict[str, Any] | None = None
674 """Dimension universe configuration."""
676 data_id: dict[str, Any] = pydantic.Field(default_factory=dict)
677 """Data ID that constrains all quanta generated from this pipeline."""
679 step_labels: list[str] = pydantic.Field(default_factory=list)
680 """List of task subset labels that are steps."""
682 steps_verified: bool
683 """Whether the step definitions in this pipeline were checked for
684 consistency with the task and subset definitions.
685 """
687 @classmethod
688 def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:
689 """Transform a `PipelineGraph` into a `SerializedPipelineGraph`.
691 Parameters
692 ----------
693 target : `PipelineGraph`
694 Object to serialize.
696 Returns
697 -------
698 `SerializedPipelineGraph`
699 Object in serializable form.
700 """
701 result = SerializedPipelineGraph.model_construct(
702 description=target.description,
703 tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()},
704 dataset_types={
705 name: SerializedDatasetTypeNode().serialize(target.dataset_types.get_if_resolved(name))
706 for name in target.dataset_types
707 },
708 task_subsets={
709 label: SerializedTaskSubset.serialize(subset) for label, subset in target.task_subsets.items()
710 },
711 step_labels=list(target.steps),
712 steps_verified=target.steps.verified,
713 dimensions=target._universe.dimensionConfig.toDict() if target._universe is not None else None,
714 data_id=target._raw_data_id,
715 )
716 if target._sorted_keys:
717 for index, node_key in enumerate(target._sorted_keys):
718 match node_key.node_type:
719 case NodeType.TASK:
720 result.tasks[node_key.name].index = index
721 case NodeType.DATASET_TYPE:
722 result.dataset_types[node_key.name].index = index
723 case NodeType.TASK_INIT:
724 result.tasks[node_key.name].init.index = index
725 return result
727 def deserialize(
728 self,
729 import_mode: TaskImportMode,
730 ) -> PipelineGraph:
731 """Transform a `SerializedPipelineGraph` into a `PipelineGraph`.
733 Parameters
734 ----------
735 import_mode : `TaskImportMode`
736 Import mode.
738 Returns
739 -------
740 `PipelineGraph`
741 Deserialized object.
742 """
743 universe: DimensionUniverse | None = None
744 if self.dimensions is not None:
745 universe = DimensionUniverse(
746 config=DimensionConfig(
747 _expect_not_none(
748 self.dimensions,
749 "Serialized pipeline graph has not been resolved; "
750 "load it is a MutablePipelineGraph instead.",
751 )
752 )
753 )
754 xgraph = networkx.MultiDiGraph()
755 sort_index_map: dict[int, NodeKey] = {}
756 # Save the dataset type keys after the first time we make them - these
757 # may be tiny objects, but it's still to have only one copy of each
758 # value floating around the graph.
759 dataset_type_keys: dict[str, NodeKey] = {}
760 for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
761 dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name)
762 # We intentionally don't attach a DatasetTypeNode instance here
763 # yet, since we need edges to do that and those are saved with
764 # the tasks.
765 xgraph.add_node(dataset_type_key, bipartite=NodeType.DATASET_TYPE.value)
766 if serialized_dataset_type.index is not None:
767 sort_index_map[serialized_dataset_type.index] = dataset_type_key
768 dataset_type_keys[dataset_type_name] = dataset_type_key
769 for task_label, serialized_task in self.tasks.items():
770 task_key = NodeKey(NodeType.TASK, task_label)
771 task_init_key = NodeKey(NodeType.TASK_INIT, task_label)
772 task_node = serialized_task.deserialize(task_key, task_init_key, dataset_type_keys, universe)
773 if serialized_task.index is not None:
774 sort_index_map[serialized_task.index] = task_key
775 if serialized_task.init.index is not None:
776 sort_index_map[serialized_task.init.index] = task_init_key
777 xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite)
778 xgraph.add_node(task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite)
779 xgraph.add_edge(task_init_key, task_key, Edge.INIT_TO_TASK_NAME, instance=None)
780 for read_edge in task_node.init.iter_all_inputs():
781 xgraph.add_edge(
782 read_edge.dataset_type_key,
783 read_edge.task_key,
784 read_edge.connection_name,
785 instance=read_edge,
786 )
787 for write_edge in task_node.init.iter_all_outputs():
788 xgraph.add_edge(
789 write_edge.task_key,
790 write_edge.dataset_type_key,
791 write_edge.connection_name,
792 instance=write_edge,
793 )
794 for read_edge in task_node.iter_all_inputs():
795 xgraph.add_edge(
796 read_edge.dataset_type_key,
797 read_edge.task_key,
798 read_edge.connection_name,
799 instance=read_edge,
800 )
801 for write_edge in task_node.iter_all_outputs():
802 xgraph.add_edge(
803 write_edge.task_key,
804 write_edge.dataset_type_key,
805 write_edge.connection_name,
806 instance=write_edge,
807 )
808 # Iterate over dataset types again to add instances.
809 for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
810 dataset_type_key = dataset_type_keys[dataset_type_name]
811 xgraph.nodes[dataset_type_key]["instance"] = serialized_dataset_type.deserialize(
812 dataset_type_key, xgraph, universe
813 )
814 steps = StepDefinitions(
815 universe,
816 dimensions_by_label=dict.fromkeys(self.step_labels, frozenset()),
817 verified=self.steps_verified,
818 )
819 result = PipelineGraph.__new__(PipelineGraph)
820 result._init_from_args(
821 xgraph,
822 sorted_keys=[sort_index_map[i] for i in range(len(xgraph))] if sort_index_map else None,
823 task_subsets={
824 subset_label: serialized_subset.deserialize_task_subset(subset_label, xgraph, steps)
825 for subset_label, serialized_subset in self.task_subsets.items()
826 },
827 description=self.description,
828 step_definitions=steps,
829 universe=universe,
830 data_id=self.data_id,
831 )
832 result._import_and_configure(import_mode)
833 return result