Coverage for python/lsst/pipe/base/pipeline_graph/io.py: 98%
202 statements
coverage.py v7.3.2, created at 2023-10-11 09:31 +0000

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
    "expect_not_none",
    "SerializedEdge",
    "SerializedTaskInitNode",
    "SerializedTaskNode",
    "SerializedDatasetTypeNode",
    "SerializedTaskSubset",
    "SerializedPipelineGraph",
)

from collections.abc import Mapping
from typing import Any, TypeVar

import networkx
import pydantic
from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse
from lsst.daf.butler._compat import _BaseModelCompat

from .. import automatic_connection_constants as acc
from ._dataset_types import DatasetTypeNode
from ._edges import Edge, ReadEdge, WriteEdge
from ._exceptions import PipelineGraphReadError
from ._nodes import NodeKey, NodeType
from ._pipeline_graph import PipelineGraph
from ._task_subsets import TaskSubset
from ._tasks import TaskImportMode, TaskInitNode, TaskNode

_U = TypeVar("_U")

_IO_VERSION_INFO = (0, 0, 1)
"""Version tuple embedded in saved PipelineGraphs.
"""


def expect_not_none(value: _U | None, msg: str) -> _U:
    """Check that a value is not `None` and return it.

    Parameters
    ----------
    value
        Value to check.
    msg
        Error message for the case where ``value is None``.

    Returns
    -------
    value
        Value, guaranteed not to be `None`.

    Raises
    ------
    PipelineGraphReadError
        Raised with ``msg`` if ``value is None``.
    """
    if value is None:
        raise PipelineGraphReadError(msg)
    return value
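

# Illustrative example (not part of the original module): ``expect_not_none``
# is typically used to unwrap optional serialized fields during
# deserialization; ``serialized`` and ``name`` here are hypothetical local
# variables::
#
#     storage_class = expect_not_none(
#         serialized.storage_class,
#         f"Serialized dataset type {name!r} has no storage class.",
#     )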


class SerializedEdge(_BaseModelCompat):
    """Struct used to represent a serialized `Edge` in a `PipelineGraph`.

    All `ReadEdge` and `WriteEdge` state not included here is instead
    effectively serialized by the context in which a `SerializedEdge` appears
    (e.g. the keys of the nested dictionaries in which it serves as the value
    type).
    """

    dataset_type_name: str
    """Full dataset type name (including component)."""

    storage_class: str
    """Name of the storage class."""

    raw_dimensions: list[str]
    """Raw dimensions of the dataset type from the task connections."""

    is_calibration: bool = False
    """Whether this dataset type can be included in
    `~lsst.daf.butler.CollectionType.CALIBRATION` collections."""

    defer_query_constraint: bool = False
    """If `True`, by default do not include this dataset type's existence as a
    constraint on the initial data ID query in QuantumGraph generation."""

    @classmethod
    def serialize(cls, target: Edge) -> SerializedEdge:
        """Transform an `Edge` to a `SerializedEdge`."""
        return SerializedEdge.model_construct(
            storage_class=target.storage_class_name,
            dataset_type_name=target.dataset_type_name,
            raw_dimensions=sorted(target.raw_dimensions),
            is_calibration=target.is_calibration,
            defer_query_constraint=getattr(target, "defer_query_constraint", False),
        )

    def deserialize_read_edge(
        self,
        task_key: NodeKey,
        connection_name: str,
        dataset_type_keys: Mapping[str, NodeKey],
        is_prerequisite: bool = False,
    ) -> ReadEdge:
        """Transform a `SerializedEdge` to a `ReadEdge`."""
        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(self.dataset_type_name)
        return ReadEdge(
            dataset_type_key=dataset_type_keys[parent_dataset_type_name],
            task_key=task_key,
            storage_class_name=self.storage_class,
            is_prerequisite=is_prerequisite,
            component=component,
            connection_name=connection_name,
            is_calibration=self.is_calibration,
            defer_query_constraint=self.defer_query_constraint,
            raw_dimensions=frozenset(self.raw_dimensions),
        )

    def deserialize_write_edge(
        self,
        task_key: NodeKey,
        connection_name: str,
        dataset_type_keys: Mapping[str, NodeKey],
    ) -> WriteEdge:
        """Transform a `SerializedEdge` to a `WriteEdge`."""
        return WriteEdge(
            task_key=task_key,
            dataset_type_key=dataset_type_keys[self.dataset_type_name],
            storage_class_name=self.storage_class,
            connection_name=connection_name,
            is_calibration=self.is_calibration,
            raw_dimensions=frozenset(self.raw_dimensions),
        )
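

# Illustrative example (not part of the original module; the dataset type and
# storage class names are hypothetical): a SerializedEdge renders to JSON as
# a small mapping, with defaulted fields omitted when the model is dumped
# with ``exclude_defaults=True``::
#
#     {
#         "dataset_type_name": "input_catalog",
#         "storage_class": "ArrowAstropy",
#         "raw_dimensions": ["instrument", "visit"]
#     }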


class SerializedTaskInitNode(_BaseModelCompat):
    """Struct used to represent a serialized `TaskInitNode` in a
    `PipelineGraph`.

    The task label is serialized by the context in which a
    `SerializedTaskInitNode` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type), and the task class name and config
    string are saved with the corresponding `SerializedTaskNode`.
    """

    inputs: dict[str, SerializedEdge]
    """Mapping of serialized init-input edges, keyed by connection name."""

    outputs: dict[str, SerializedEdge]
    """Mapping of serialized init-output edges, keyed by connection name."""

    config_output: SerializedEdge
    """The serialized config init-output edge."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode:
        """Transform a `TaskInitNode` to a `SerializedTaskInitNode`."""
        return cls.model_construct(
            inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.inputs.items())
            },
            outputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.outputs.items())
            },
            config_output=SerializedEdge.serialize(target.config_output),
        )

    def deserialize(
        self,
        key: NodeKey,
        task_class_name: str,
        config_str: str,
        dataset_type_keys: Mapping[str, NodeKey],
    ) -> TaskInitNode:
        """Transform a `SerializedTaskInitNode` to a `TaskInitNode`."""
        return TaskInitNode(
            key,
            inputs={
                connection_name: serialized_edge.deserialize_read_edge(
                    key, connection_name, dataset_type_keys
                )
                for connection_name, serialized_edge in self.inputs.items()
            },
            outputs={
                connection_name: serialized_edge.deserialize_write_edge(
                    key, connection_name, dataset_type_keys
                )
                for connection_name, serialized_edge in self.outputs.items()
            },
            config_output=self.config_output.deserialize_write_edge(
                key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, dataset_type_keys
            ),
            task_class_name=task_class_name,
            config_str=config_str,
        )
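

# Illustrative example (not part of the original module; the connection and
# dataset type names are hypothetical): a SerializedTaskInitNode for a task
# with no init-inputs might render to JSON as::
#
#     {
#         "inputs": {},
#         "outputs": {"output_schema": {...}},
#         "config_output": {
#             "dataset_type_name": "isr_config",
#             "storage_class": "Config",
#             "raw_dimensions": []
#         },
#         "index": 0
#     }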


class SerializedTaskNode(_BaseModelCompat):
    """Struct used to represent a serialized `TaskNode` in a `PipelineGraph`.

    The task label is serialized by the context in which a
    `SerializedTaskNode` appears (e.g. the keys of the nested dictionary in
    which it serves as the value type).
    """

    task_class: str
    """Fully-qualified name of the task class."""

    init: SerializedTaskInitNode
    """Serialized task initialization node."""

    config_str: str
    """Configuration for the task as a string of override statements."""

    prerequisite_inputs: dict[str, SerializedEdge]
    """Mapping of serialized prerequisite input edges, keyed by connection
    name.
    """

    inputs: dict[str, SerializedEdge]
    """Mapping of serialized input edges, keyed by connection name."""

    outputs: dict[str, SerializedEdge]
    """Mapping of serialized output edges, keyed by connection name."""

    metadata_output: SerializedEdge
    """The serialized metadata output edge."""

    dimensions: list[str]
    """The task's dimensions, if they were resolved."""

    log_output: SerializedEdge | None = None
    """The serialized log output edge."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: TaskNode) -> SerializedTaskNode:
        """Transform a `TaskNode` to a `SerializedTaskNode`."""
        return cls.model_construct(
            task_class=target.task_class_name,
            init=SerializedTaskInitNode.serialize(target.init),
            config_str=target.get_config_str(),
            dimensions=list(target.raw_dimensions),
            prerequisite_inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.prerequisite_inputs.items())
            },
            inputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.inputs.items())
            },
            outputs={
                connection_name: SerializedEdge.serialize(edge)
                for connection_name, edge in sorted(target.outputs.items())
            },
            metadata_output=SerializedEdge.serialize(target.metadata_output),
            log_output=(
                SerializedEdge.serialize(target.log_output) if target.log_output is not None else None
            ),
        )

    def deserialize(
        self,
        key: NodeKey,
        init_key: NodeKey,
        dataset_type_keys: Mapping[str, NodeKey],
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Transform a `SerializedTaskNode` to a `TaskNode`."""
        init = self.init.deserialize(
            init_key,
            task_class_name=self.task_class,
            config_str=expect_not_none(
                self.config_str, f"No serialized config file for task with label {key.name!r}."
            ),
            dataset_type_keys=dataset_type_keys,
        )
        inputs = {
            connection_name: serialized_edge.deserialize_read_edge(key, connection_name, dataset_type_keys)
            for connection_name, serialized_edge in self.inputs.items()
        }
        prerequisite_inputs = {
            connection_name: serialized_edge.deserialize_read_edge(
                key, connection_name, dataset_type_keys, is_prerequisite=True
            )
            for connection_name, serialized_edge in self.prerequisite_inputs.items()
        }
        outputs = {
            connection_name: serialized_edge.deserialize_write_edge(key, connection_name, dataset_type_keys)
            for connection_name, serialized_edge in self.outputs.items()
        }
        if (serialized_log_output := self.log_output) is not None:
            log_output = serialized_log_output.deserialize_write_edge(
                key, acc.LOG_OUTPUT_CONNECTION_NAME, dataset_type_keys
            )
        else:
            log_output = None
        metadata_output = self.metadata_output.deserialize_write_edge(
            key, acc.METADATA_OUTPUT_CONNECTION_NAME, dataset_type_keys
        )
        dimensions: frozenset[str] | DimensionGraph
        if universe is not None:
            dimensions = universe.extract(self.dimensions)
        else:
            dimensions = frozenset(self.dimensions)
        return TaskNode(
            key=key,
            init=init,
            inputs=inputs,
            prerequisite_inputs=prerequisite_inputs,
            outputs=outputs,
            log_output=log_output,
            metadata_output=metadata_output,
            dimensions=dimensions,
        )
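

# Illustrative example (not part of the original module; the task label,
# class, and connection names are hypothetical): one entry of
# ``SerializedPipelineGraph.tasks`` might render to JSON as::
#
#     "isr": {
#         "task_class": "lsst.ip.isr.IsrTask",
#         "init": {...},
#         "config_str": "config.doDark=False\n",
#         "prerequisite_inputs": {"darks": {...}},
#         "inputs": {"raws": {...}},
#         "outputs": {"exposures": {...}},
#         "metadata_output": {...},
#         "dimensions": ["instrument", "detector", "exposure"],
#         "log_output": {...},
#         "index": 5
#     }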


class SerializedDatasetTypeNode(_BaseModelCompat):
    """Struct used to represent a serialized `DatasetTypeNode` in a
    `PipelineGraph`.

    Unresolved dataset types are serialized as instances with at most the
    `index` attribute set, and are typically converted to JSON with pydantic's
    ``exclude_defaults=True`` option to keep this compact.

    The dataset type name is serialized by the context in which a
    `SerializedDatasetTypeNode` appears (e.g. the keys of the nested
    dictionary in which it serves as the value type).
    """

    dimensions: list[str] | None = None
    """Dimensions of the dataset type."""

    storage_class: str | None = None
    """Name of the storage class."""

    is_calibration: bool = False
    """Whether this dataset type is a calibration."""

    is_initial_query_constraint: bool = False
    """Whether this dataset type should be a query constraint during
    `QuantumGraph` generation."""

    is_prerequisite: bool = False
    """Whether datasets of this dataset type must exist in the input
    collection before `QuantumGraph` generation."""

    index: int | None = None
    """The index of this node in the sorted sequence of `PipelineGraph`.

    This is `None` if the `PipelineGraph` was not sorted when it was
    serialized.
    """

    @classmethod
    def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode:
        """Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`."""
        if target is None:
            return cls.model_construct()
        return cls.model_construct(
            dimensions=list(target.dataset_type.dimensions.names),
            storage_class=target.dataset_type.storageClass_name,
            is_calibration=target.dataset_type.isCalibration(),
            is_initial_query_constraint=target.is_initial_query_constraint,
            is_prerequisite=target.is_prerequisite,
        )

    def deserialize(
        self, key: NodeKey, xgraph: networkx.MultiDiGraph, universe: DimensionUniverse | None
    ) -> DatasetTypeNode | None:
        """Transform a `SerializedDatasetTypeNode` to a `DatasetTypeNode`."""
        if self.dimensions is not None:
            dataset_type = DatasetType(
                key.name,
                expect_not_none(
                    self.dimensions,
                    f"Serialized dataset type {key.name!r} has no dimensions.",
                ),
                storageClass=expect_not_none(
                    self.storage_class,
                    f"Serialized dataset type {key.name!r} has no storage class.",
                ),
                isCalibration=self.is_calibration,
                universe=expect_not_none(
                    universe,
                    f"Serialized dataset type {key.name!r} has dimensions, "
                    "but no dimension universe was stored.",
                ),
            )
            producer: str | None = None
            producing_edge: WriteEdge | None = None
            # Edges into a dataset type node are writes (producers); edges
            # out of it are reads (consumers).
            for _, _, producing_edge in xgraph.in_edges(key, data="instance"):
                assert producing_edge is not None, "Should only be None if we never loop."
                if producer is not None:
                    raise PipelineGraphReadError(
                        f"Serialized dataset type {key.name!r} is produced by both "
                        f"{producing_edge.task_label!r} and {producer!r} in resolved graph."
                    )
                producer = producing_edge.task_label
            consuming_edges = tuple(
                consuming_edge for _, _, consuming_edge in xgraph.out_edges(key, data="instance")
            )
            return DatasetTypeNode(
                dataset_type=dataset_type,
                is_prerequisite=self.is_prerequisite,
                is_initial_query_constraint=self.is_initial_query_constraint,
                producing_edge=producing_edge,
                consuming_edges=consuming_edges,
            )
        return None
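

# Illustrative example (not part of the original module; the dataset type
# names are hypothetical): with ``exclude_defaults=True``, an unresolved
# dataset type serializes as an empty mapping, while a resolved one carries
# its full definition::
#
#     "raw": {},
#     "calexp": {
#         "dimensions": ["instrument", "visit", "detector"],
#         "storage_class": "ExposureF",
#         "is_initial_query_constraint": true
#     }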


class SerializedTaskSubset(_BaseModelCompat):
    """Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`.

    The subset label is serialized by the context in which a
    `SerializedTaskSubset` appears (e.g. the keys of the nested dictionary
    in which it serves as the value type).
    """

    description: str
    """Description of the subset."""

    tasks: list[str]
    """Labels of tasks in the subset, sorted lexicographically for
    determinism.
    """

    @classmethod
    def serialize(cls, target: TaskSubset) -> SerializedTaskSubset:
        """Transform a `TaskSubset` into a `SerializedTaskSubset`."""
        return cls.model_construct(description=target._description, tasks=sorted(target))

    def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) -> TaskSubset:
        """Transform a `SerializedTaskSubset` into a `TaskSubset`."""
        members = set(self.tasks)
        return TaskSubset(xgraph, label, members, self.description)
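

# Illustrative example (not part of the original module; the subset label and
# task labels are hypothetical): one entry of
# ``SerializedPipelineGraph.task_subsets`` might render to JSON as::
#
#     "step1": {
#         "description": "Single-frame processing.",
#         "tasks": ["calibrate", "characterizeImage", "isr"]
#     }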


class SerializedPipelineGraph(_BaseModelCompat):
    """Struct used to represent a serialized `PipelineGraph`."""

    version: str = ".".join(str(v) for v in _IO_VERSION_INFO)
    """Serialization version."""

    description: str
    """Human-readable description of the pipeline."""

    tasks: dict[str, SerializedTaskNode] = pydantic.Field(default_factory=dict)
    """Mapping of serialized tasks, keyed by label."""

    dataset_types: dict[str, SerializedDatasetTypeNode] = pydantic.Field(default_factory=dict)
    """Mapping of serialized dataset types, keyed by parent dataset type name.
    """

    task_subsets: dict[str, SerializedTaskSubset] = pydantic.Field(default_factory=dict)
    """Mapping of task subsets, keyed by subset label."""

    dimensions: dict[str, Any] | None = None
    """Dimension universe configuration."""

    data_id: dict[str, Any] = pydantic.Field(default_factory=dict)
    """Data ID that constrains all quanta generated from this pipeline."""

    @classmethod
    def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:
        """Transform a `PipelineGraph` into a `SerializedPipelineGraph`."""
        result = SerializedPipelineGraph.model_construct(
            description=target.description,
            tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()},
            dataset_types={
                name: SerializedDatasetTypeNode.serialize(target.dataset_types.get_if_resolved(name))
                for name in target.dataset_types
            },
            task_subsets={
                label: SerializedTaskSubset.serialize(subset) for label, subset in target.task_subsets.items()
            },
            dimensions=target.universe.dimensionConfig.toDict() if target.universe is not None else None,
            data_id=target._raw_data_id,
        )
        if target._sorted_keys:
            for index, node_key in enumerate(target._sorted_keys):
                match node_key.node_type:
                    case NodeType.TASK:
                        result.tasks[node_key.name].index = index
                    case NodeType.DATASET_TYPE:
                        result.dataset_types[node_key.name].index = index
                    case NodeType.TASK_INIT:
                        result.tasks[node_key.name].init.index = index
        return result

    def deserialize(
        self,
        import_mode: TaskImportMode,
    ) -> PipelineGraph:
        """Transform a `SerializedPipelineGraph` into a `PipelineGraph`."""
        universe: DimensionUniverse | None = None
        if self.dimensions is not None:
            universe = DimensionUniverse(
                config=DimensionConfig(
                    expect_not_none(
                        self.dimensions,
                        "Serialized pipeline graph has not been resolved; "
                        "load it as a MutablePipelineGraph instead.",
                    )
                )
            )
        xgraph = networkx.MultiDiGraph()
        sort_index_map: dict[int, NodeKey] = {}
        # Save the dataset type keys after the first time we make them - these
        # may be tiny objects, but it's still better to have only one copy of
        # each value floating around the graph.
        dataset_type_keys: dict[str, NodeKey] = {}
        for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
            dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name)
            # We intentionally don't attach a DatasetTypeNode instance here
            # yet, since we need edges to do that and those are saved with
            # the tasks.
            xgraph.add_node(dataset_type_key, bipartite=NodeType.DATASET_TYPE.bipartite)
            if serialized_dataset_type.index is not None:
                sort_index_map[serialized_dataset_type.index] = dataset_type_key
            dataset_type_keys[dataset_type_name] = dataset_type_key
        for task_label, serialized_task in self.tasks.items():
            task_key = NodeKey(NodeType.TASK, task_label)
            task_init_key = NodeKey(NodeType.TASK_INIT, task_label)
            task_node = serialized_task.deserialize(task_key, task_init_key, dataset_type_keys, universe)
            if serialized_task.index is not None:
                sort_index_map[serialized_task.index] = task_key
            if serialized_task.init.index is not None:
                sort_index_map[serialized_task.init.index] = task_init_key
            xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite)
            xgraph.add_node(task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite)
            xgraph.add_edge(task_init_key, task_key, Edge.INIT_TO_TASK_NAME, instance=None)
            for read_edge in task_node.init.iter_all_inputs():
                xgraph.add_edge(
                    read_edge.dataset_type_key,
                    read_edge.task_key,
                    read_edge.connection_name,
                    instance=read_edge,
                )
            for write_edge in task_node.init.iter_all_outputs():
                xgraph.add_edge(
                    write_edge.task_key,
                    write_edge.dataset_type_key,
                    write_edge.connection_name,
                    instance=write_edge,
                )
            for read_edge in task_node.iter_all_inputs():
                xgraph.add_edge(
                    read_edge.dataset_type_key,
                    read_edge.task_key,
                    read_edge.connection_name,
                    instance=read_edge,
                )
            for write_edge in task_node.iter_all_outputs():
                xgraph.add_edge(
                    write_edge.task_key,
                    write_edge.dataset_type_key,
                    write_edge.connection_name,
                    instance=write_edge,
                )
        # Iterate over dataset types again to add instances.
        for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
            dataset_type_key = dataset_type_keys[dataset_type_name]
            xgraph.nodes[dataset_type_key]["instance"] = serialized_dataset_type.deserialize(
                dataset_type_key, xgraph, universe
            )
        result = PipelineGraph.__new__(PipelineGraph)
        result._init_from_args(
            xgraph,
            sorted_keys=[sort_index_map[i] for i in range(len(xgraph))] if sort_index_map else None,
            task_subsets={
                subset_label: serialized_subset.deserialize_task_subset(subset_label, xgraph)
                for subset_label, serialized_subset in self.task_subsets.items()
            },
            description=self.description,
            universe=universe,
            data_id=self.data_id,
        )
        result._import_and_configure(import_mode)
        return result
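

# Illustrative usage (a minimal sketch, not part of the original module;
# assumes a resolved `PipelineGraph` bound to ``graph`` and pydantic-v2-style
# JSON methods provided via ``_BaseModelCompat``)::
#
#     serialized = SerializedPipelineGraph.serialize(graph)
#     json_str = serialized.model_dump_json(exclude_defaults=True)
#     restored = SerializedPipelineGraph.model_validate_json(json_str).deserialize(
#         TaskImportMode.DO_NOT_IMPORT
#     )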