Coverage for python/lsst/pipe/base/pipeline_graph/io.py: 98% (202 statements)
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
    "expect_not_none",
    "SerializedEdge",
    "SerializedTaskInitNode",
    "SerializedTaskNode",
    "SerializedDatasetTypeNode",
    "SerializedTaskSubset",
    "SerializedPipelineGraph",
)

from collections.abc import Mapping
from typing import Any, TypeVar

import networkx
import pydantic
from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse
from lsst.daf.butler._compat import _BaseModelCompat

from .. import automatic_connection_constants as acc
from ._dataset_types import DatasetTypeNode
from ._edges import Edge, ReadEdge, WriteEdge
from ._exceptions import PipelineGraphReadError
from ._nodes import NodeKey, NodeType
from ._pipeline_graph import PipelineGraph
from ._task_subsets import TaskSubset
from ._tasks import TaskImportMode, TaskInitNode, TaskNode

_U = TypeVar("_U")

_IO_VERSION_INFO = (0, 0, 1)
"""Version tuple embedded in saved PipelineGraphs."""


def expect_not_none(value: _U | None, msg: str) -> _U:
    """Check that a value is not `None` and return it.

    Parameters
    ----------
    value
        Value to check.
    msg
        Error message for the case where ``value is None``.

    Returns
    -------
    value
        Value, guaranteed not to be `None`.

    Raises
    ------
    PipelineGraphReadError
        Raised with ``msg`` if ``value is None``.
    """
    if value is None:
        raise PipelineGraphReadError(msg)
    return value
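
# A minimal usage sketch (illustrative only; ``maybe_config`` is a made-up
# variable, not part of this module):
#
#     config_str = expect_not_none(maybe_config, "No serialized config found.")
#
# A non-None value passes through unchanged; None raises
# PipelineGraphReadError with the given message.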


class SerializedEdge(_BaseModelCompat):
    """Struct used to represent a serialized `Edge` in a `PipelineGraph`.

    All `ReadEdge` and `WriteEdge` state not included here is instead
    effectively serialized by the context in which a `SerializedEdge` appears
    (e.g. the keys of the nested dictionaries in which it serves as the value
    type).
    """

    dataset_type_name: str
    """Full dataset type name (including component)."""

    storage_class: str
    """Name of the storage class."""

    raw_dimensions: list[str]
    """Raw dimensions of the dataset type from the task connections."""

    is_calibration: bool = False
    """Whether this dataset type can be included in
    `~lsst.daf.butler.CollectionType.CALIBRATION` collections."""

    defer_query_constraint: bool = False
    """If `True`, by default do not include this dataset type's existence as a
    constraint on the initial data ID query in QuantumGraph generation."""

    @classmethod
    def serialize(cls, target: Edge) -> SerializedEdge:
        """Transform an `Edge` to a `SerializedEdge`."""
        return SerializedEdge.model_construct(
            storage_class=target.storage_class_name,
            dataset_type_name=target.dataset_type_name,
            raw_dimensions=sorted(target.raw_dimensions),
            is_calibration=target.is_calibration,
            defer_query_constraint=getattr(target, "defer_query_constraint", False),
        )

    def deserialize_read_edge(
        self,
        task_key: NodeKey,
        connection_name: str,
        dataset_type_keys: Mapping[str, NodeKey],
        is_prerequisite: bool = False,
    ) -> ReadEdge:
        """Transform a `SerializedEdge` to a `ReadEdge`."""
        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(self.dataset_type_name)
        return ReadEdge(
            dataset_type_key=dataset_type_keys[parent_dataset_type_name],
            task_key=task_key,
            storage_class_name=self.storage_class,
            is_prerequisite=is_prerequisite,
            component=component,
            connection_name=connection_name,
            is_calibration=self.is_calibration,
            defer_query_constraint=self.defer_query_constraint,
            raw_dimensions=frozenset(self.raw_dimensions),
        )

    def deserialize_write_edge(
        self,
        task_key: NodeKey,
        connection_name: str,
        dataset_type_keys: Mapping[str, NodeKey],
    ) -> WriteEdge:
        """Transform a `SerializedEdge` to a `WriteEdge`."""
        return WriteEdge(
            task_key=task_key,
            dataset_type_key=dataset_type_keys[self.dataset_type_name],
            storage_class_name=self.storage_class,
            connection_name=connection_name,
            is_calibration=self.is_calibration,
            raw_dimensions=frozenset(self.raw_dimensions),
        )
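
# Illustrative sketch (hypothetical values): a component dataset type name
# like "calexp.wcs" is split on deserialization, so the graph node key always
# refers to the parent dataset type:
#
#     edge = SerializedEdge.model_construct(
#         dataset_type_name="calexp.wcs",
#         storage_class="Wcs",
#         raw_dimensions=["instrument", "visit", "detector"],
#     )
#     # deserialize_read_edge looks up dataset_type_keys["calexp"], not
#     # dataset_type_keys["calexp.wcs"].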


class SerializedTaskInitNode(_BaseModelCompat):
    """Struct used to represent a serialized `TaskInitNode` in a
    `PipelineGraph`.

    The task label is serialized by the context in which a
    `SerializedTaskInitNode` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type), and the task class name and config
    string are saved with the corresponding `SerializedTaskNode`.
    """

    inputs: dict[str, SerializedEdge]
    """Mapping of serialized init-input edges, keyed by connection name."""

    outputs: dict[str, SerializedEdge]
    """Mapping of serialized init-output edges, keyed by connection name."""

    config_output: SerializedEdge
    """The serialized config init-output edge."""

    index: int | None = None
    """The index of this node in the sorted sequence of the `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode:
        """Transform a `TaskInitNode` to a `SerializedTaskInitNode`."""
        return cls.model_construct(
            inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.inputs.items())
            },
            outputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.outputs.items())
            },
            config_output=SerializedEdge.serialize(target.config_output),
        )

    def deserialize(
        self,
        key: NodeKey,
        task_class_name: str,
        config_str: str,
        dataset_type_keys: Mapping[str, NodeKey],
    ) -> TaskInitNode:
        """Transform a `SerializedTaskInitNode` to a `TaskInitNode`."""
        return TaskInitNode(
            key,
            inputs={
                connection_name: serialized_edge.deserialize_read_edge(
                    key, connection_name, dataset_type_keys
                )
                for connection_name, serialized_edge in self.inputs.items()
            },
            outputs={
                connection_name: serialized_edge.deserialize_write_edge(
                    key, connection_name, dataset_type_keys
                )
                for connection_name, serialized_edge in self.outputs.items()
            },
            config_output=self.config_output.deserialize_write_edge(
                key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, dataset_type_keys
            ),
            task_class_name=task_class_name,
            config_str=config_str,
        )
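
# Note (illustrative): the connection mappings are sorted by name at
# serialization time, so repeated saves of the same graph compare equal, e.g.
#
#     sorted({"b": 2, "a": 1}.items())  # -> [("a", 1), ("b", 2)]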


class SerializedTaskNode(_BaseModelCompat):
    """Struct used to represent a serialized `TaskNode` in a `PipelineGraph`.

    The task label is serialized by the context in which a
    `SerializedTaskNode` appears (e.g. the keys of the nested dictionary in
    which it serves as the value type).
    """

    task_class: str
    """Fully-qualified name of the task class."""

    init: SerializedTaskInitNode
    """Serialized task initialization node."""

    config_str: str
    """Configuration for the task as a string of override statements."""

    prerequisite_inputs: dict[str, SerializedEdge]
    """Mapping of serialized prerequisite input edges, keyed by connection
    name.
    """

    inputs: dict[str, SerializedEdge]
    """Mapping of serialized input edges, keyed by connection name."""

    outputs: dict[str, SerializedEdge]
    """Mapping of serialized output edges, keyed by connection name."""

    metadata_output: SerializedEdge
    """The serialized metadata output edge."""

    dimensions: list[str]
    """The task's dimensions, if they were resolved."""

    log_output: SerializedEdge | None = None
    """The serialized log output edge."""

    index: int | None = None
    """The index of this node in the sorted sequence of the `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: TaskNode) -> SerializedTaskNode:
        """Transform a `TaskNode` to a `SerializedTaskNode`."""
        return cls.model_construct(
            task_class=target.task_class_name,
            init=SerializedTaskInitNode.serialize(target.init),
            config_str=target.get_config_str(),
            dimensions=list(target.raw_dimensions),
            prerequisite_inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.prerequisite_inputs.items())
            },
            inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.inputs.items())
            },
            outputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.outputs.items())
            },
            metadata_output=SerializedEdge.serialize(target.metadata_output),
            log_output=(
                SerializedEdge.serialize(target.log_output) if target.log_output is not None else None
            ),
        )

    def deserialize(
        self,
        key: NodeKey,
        init_key: NodeKey,
        dataset_type_keys: Mapping[str, NodeKey],
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Transform a `SerializedTaskNode` to a `TaskNode`."""
        init = self.init.deserialize(
            init_key,
            task_class_name=self.task_class,
            config_str=expect_not_none(
                self.config_str, f"No serialized config file for task with label {key.name!r}."
            ),
            dataset_type_keys=dataset_type_keys,
        )
        inputs = {
            connection_name: serialized_edge.deserialize_read_edge(key, connection_name, dataset_type_keys)
            for connection_name, serialized_edge in self.inputs.items()
        }
        prerequisite_inputs = {
            connection_name: serialized_edge.deserialize_read_edge(
                key, connection_name, dataset_type_keys, is_prerequisite=True
            )
            for connection_name, serialized_edge in self.prerequisite_inputs.items()
        }
        outputs = {
            connection_name: serialized_edge.deserialize_write_edge(key, connection_name, dataset_type_keys)
            for connection_name, serialized_edge in self.outputs.items()
        }
        if (serialized_log_output := self.log_output) is not None:
            log_output = serialized_log_output.deserialize_write_edge(
                key, acc.LOG_OUTPUT_CONNECTION_NAME, dataset_type_keys
            )
        else:
            log_output = None
        metadata_output = self.metadata_output.deserialize_write_edge(
            key, acc.METADATA_OUTPUT_CONNECTION_NAME, dataset_type_keys
        )
        dimensions: frozenset[str] | DimensionGraph
        if universe is not None:
            dimensions = universe.extract(self.dimensions)
        else:
            dimensions = frozenset(self.dimensions)
        return TaskNode(
            key=key,
            init=init,
            inputs=inputs,
            prerequisite_inputs=prerequisite_inputs,
            outputs=outputs,
            log_output=log_output,
            metadata_output=metadata_output,
            dimensions=dimensions,
        )
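
# Illustrative note: when a universe is available, the raw dimension names are
# conformed into a DimensionGraph (which can pull in required dependencies);
# without one they are preserved verbatim as a frozenset of strings.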


class SerializedDatasetTypeNode(_BaseModelCompat):
    """Struct used to represent a serialized `DatasetTypeNode` in a
    `PipelineGraph`.

    Unresolved dataset types are serialized as instances with at most the
    `index` attribute set, and are typically converted to JSON with pydantic's
    ``exclude_defaults=True`` option to keep this compact.

    The dataset type name is serialized by the context in which a
    `SerializedDatasetTypeNode` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type).
    """

    dimensions: list[str] | None = None
    """Dimensions of the dataset type."""

    storage_class: str | None = None
    """Name of the storage class."""

    is_calibration: bool = False
    """Whether this dataset type is a calibration."""

    is_initial_query_constraint: bool = False
    """Whether this dataset type should be a query constraint during
    `QuantumGraph` generation."""

    is_prerequisite: bool = False
    """Whether datasets of this dataset type must exist in the input collection
    before `QuantumGraph` generation."""

    index: int | None = None
    """The index of this node in the sorted sequence of the `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode:
        """Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`."""
        if target is None:
            return cls.model_construct()
        return cls.model_construct(
            dimensions=list(target.dataset_type.dimensions.names),
            storage_class=target.dataset_type.storageClass_name,
            is_calibration=target.dataset_type.isCalibration(),
            is_initial_query_constraint=target.is_initial_query_constraint,
            is_prerequisite=target.is_prerequisite,
        )

    def deserialize(
        self, key: NodeKey, xgraph: networkx.MultiDiGraph, universe: DimensionUniverse | None
    ) -> DatasetTypeNode | None:
        """Transform a `SerializedDatasetTypeNode` to a `DatasetTypeNode`."""
        if self.dimensions is not None:
            dataset_type = DatasetType(
                key.name,
                expect_not_none(
                    self.dimensions,
                    f"Serialized dataset type {key.name!r} has no dimensions.",
                ),
                storageClass=expect_not_none(
                    self.storage_class,
                    f"Serialized dataset type {key.name!r} has no storage class.",
                ),
                isCalibration=self.is_calibration,
                universe=expect_not_none(
                    universe,
                    f"Serialized dataset type {key.name!r} has dimensions, "
                    "but no dimension universe was stored.",
                ),
            )
            producer: str | None = None
            producing_edge: WriteEdge | None = None
            for _, _, producing_edge in xgraph.in_edges(key, data="instance"):
                assert producing_edge is not None, "Should only be None if we never loop."
                if producer is not None:
                    raise PipelineGraphReadError(
                        f"Serialized dataset type {key.name!r} is produced by both "
                        f"{producing_edge.task_label!r} and {producer!r} in resolved graph."
                    )
                producer = producing_edge.task_label
            # Consumers read from this node, so they are the out-edges of the
            # dataset type node in the bipartite graph.
            consuming_edges = tuple(
                consuming_edge for _, _, consuming_edge in xgraph.out_edges(key, data="instance")
            )
            return DatasetTypeNode(
                dataset_type=dataset_type,
                is_prerequisite=self.is_prerequisite,
                is_initial_query_constraint=self.is_initial_query_constraint,
                producing_edge=producing_edge,
                consuming_edges=consuming_edges,
            )
        return None
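
# Illustrative sketch (assuming a pydantic v2 runtime behind
# _BaseModelCompat): an unresolved dataset type serializes compactly,
#
#     SerializedDatasetTypeNode().model_dump(exclude_defaults=True)  # -> {}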


class SerializedTaskSubset(_BaseModelCompat):
    """Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`.

    The subset label is serialized by the context in which a
    `SerializedTaskSubset` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type).
    """

    description: str
    """Description of the subset."""

    tasks: list[str]
    """Labels of tasks in the subset, sorted lexicographically for
    determinism.
    """

    @classmethod
    def serialize(cls, target: TaskSubset) -> SerializedTaskSubset:
        """Transform a `TaskSubset` into a `SerializedTaskSubset`."""
        return cls.model_construct(description=target._description, tasks=sorted(target))

    def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) -> TaskSubset:
        """Transform a `SerializedTaskSubset` into a `TaskSubset`."""
        members = set(self.tasks)
        return TaskSubset(xgraph, label, members, self.description)


class SerializedPipelineGraph(_BaseModelCompat):
    """Struct used to represent a serialized `PipelineGraph`."""

    version: str = ".".join(str(v) for v in _IO_VERSION_INFO)
    """Serialization version."""

    description: str
    """Human-readable description of the pipeline."""

    tasks: dict[str, SerializedTaskNode] = pydantic.Field(default_factory=dict)
    """Mapping of serialized tasks, keyed by label."""

    dataset_types: dict[str, SerializedDatasetTypeNode] = pydantic.Field(default_factory=dict)
    """Mapping of serialized dataset types, keyed by parent dataset type name.
    """

    task_subsets: dict[str, SerializedTaskSubset] = pydantic.Field(default_factory=dict)
    """Mapping of task subsets, keyed by subset label."""

    dimensions: dict[str, Any] | None = None
    """Dimension universe configuration."""

    data_id: dict[str, Any] = pydantic.Field(default_factory=dict)
    """Data ID that constrains all quanta generated from this pipeline."""

    @classmethod
    def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:
        """Transform a `PipelineGraph` into a `SerializedPipelineGraph`."""
        result = SerializedPipelineGraph.model_construct(
            description=target.description,
            tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()},
            dataset_types={
                name: SerializedDatasetTypeNode.serialize(target.dataset_types.get_if_resolved(name))
                for name in target.dataset_types
            },
            task_subsets={
                label: SerializedTaskSubset.serialize(subset) for label, subset in target.task_subsets.items()
            },
            dimensions=target.universe.dimensionConfig.toDict() if target.universe is not None else None,
            data_id=target._raw_data_id,
        )
        if target._sorted_keys:
            for index, node_key in enumerate(target._sorted_keys):
                match node_key.node_type:
                    case NodeType.TASK:
                        result.tasks[node_key.name].index = index
                    case NodeType.DATASET_TYPE:
                        result.dataset_types[node_key.name].index = index
                    case NodeType.TASK_INIT:
                        result.tasks[node_key.name].init.index = index
        return result
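
    # Illustrative note: when the graph was sorted at save time, each node
    # records its position; ``deserialize`` inverts that mapping to rebuild
    # the exact topological order, roughly
    #
    #     sorted_keys = [sort_index_map[i] for i in range(len(xgraph))]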

    def deserialize(
        self,
        import_mode: TaskImportMode,
    ) -> PipelineGraph:
        """Transform a `SerializedPipelineGraph` into a `PipelineGraph`."""
        universe: DimensionUniverse | None = None
        if self.dimensions is not None:
            universe = DimensionUniverse(
                config=DimensionConfig(
                    expect_not_none(
                        self.dimensions,
                        "Serialized pipeline graph has not been resolved; "
                        "load it as a MutablePipelineGraph instead.",
                    )
                )
            )
        xgraph = networkx.MultiDiGraph()
        sort_index_map: dict[int, NodeKey] = {}
        # Save the dataset type keys after the first time we make them - these
        # may be tiny objects, but it's still better to have only one copy of
        # each value floating around the graph.
        dataset_type_keys: dict[str, NodeKey] = {}
        for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
            dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name)
            # We intentionally don't attach a DatasetTypeNode instance here
            # yet, since we need edges to do that and those are saved with
            # the tasks.
            xgraph.add_node(dataset_type_key, bipartite=NodeType.DATASET_TYPE.bipartite)
            if serialized_dataset_type.index is not None:
                sort_index_map[serialized_dataset_type.index] = dataset_type_key
            dataset_type_keys[dataset_type_name] = dataset_type_key
        for task_label, serialized_task in self.tasks.items():
            task_key = NodeKey(NodeType.TASK, task_label)
            task_init_key = NodeKey(NodeType.TASK_INIT, task_label)
            task_node = serialized_task.deserialize(task_key, task_init_key, dataset_type_keys, universe)
            if serialized_task.index is not None:
                sort_index_map[serialized_task.index] = task_key
            if serialized_task.init.index is not None:
                sort_index_map[serialized_task.init.index] = task_init_key
            xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite)
            xgraph.add_node(task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite)
            xgraph.add_edge(task_init_key, task_key, Edge.INIT_TO_TASK_NAME, instance=None)
            for read_edge in task_node.init.iter_all_inputs():
                xgraph.add_edge(
                    read_edge.dataset_type_key,
                    read_edge.task_key,
                    read_edge.connection_name,
                    instance=read_edge,
                )
            for write_edge in task_node.init.iter_all_outputs():
                xgraph.add_edge(
                    write_edge.task_key,
                    write_edge.dataset_type_key,
                    write_edge.connection_name,
                    instance=write_edge,
                )
            for read_edge in task_node.iter_all_inputs():
                xgraph.add_edge(
                    read_edge.dataset_type_key,
                    read_edge.task_key,
                    read_edge.connection_name,
                    instance=read_edge,
                )
            for write_edge in task_node.iter_all_outputs():
                xgraph.add_edge(
                    write_edge.task_key,
                    write_edge.dataset_type_key,
                    write_edge.connection_name,
                    instance=write_edge,
                )
        # Iterate over dataset types again to add instances.
        for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
            dataset_type_key = dataset_type_keys[dataset_type_name]
            xgraph.nodes[dataset_type_key]["instance"] = serialized_dataset_type.deserialize(
                dataset_type_key, xgraph, universe
            )
        result = PipelineGraph.__new__(PipelineGraph)
        result._init_from_args(
            xgraph,
            sorted_keys=[sort_index_map[i] for i in range(len(xgraph))] if sort_index_map else None,
            task_subsets={
                subset_label: serialized_subset.deserialize_task_subset(subset_label, xgraph)
                for subset_label, serialized_subset in self.task_subsets.items()
            },
            description=self.description,
            universe=universe,
            data_id=self.data_id,
        )
        result._import_and_configure(import_mode)
        return result
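
# Illustrative round trip (a sketch, not part of this module; ``graph`` is
# assumed to be a resolved PipelineGraph, TaskImportMode.DO_NOT_IMPORT an
# available import mode, and the JSON methods assume a pydantic v2 runtime
# behind _BaseModelCompat):
#
#     serialized = SerializedPipelineGraph.serialize(graph)
#     json_str = serialized.model_dump_json(exclude_defaults=True)
#     loaded = SerializedPipelineGraph.model_validate_json(json_str)
#     graph2 = loaded.deserialize(TaskImportMode.DO_NOT_IMPORT)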