Coverage for python/lsst/pipe/base/graph/_implDetails.py: 29%
50 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-31 09:39 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-31 09:39 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("_DatasetTracker", "DatasetTypeName")
25from collections import defaultdict
26from typing import NewType
28import networkx as nx
30from ..pipeline import TaskDef
32# NewTypes
33DatasetTypeName = NewType("DatasetTypeName", str)
36class _DatasetTracker:
37 r"""A container for tracking the relationships between tasks and dataset
38 types.
40 Prameters
41 ---------
42 createInverse : bool
43 When adding a key associated with a producer or consumer, also create
44 and inverse mapping that allows looking up all the keys associated with
45 some value. Defaults to False.
46 """
48 def __init__(self, createInverse: bool = False):
49 self._producers: dict[DatasetTypeName, TaskDef] = {}
50 self._consumers: defaultdict[DatasetTypeName, set[TaskDef]] = defaultdict(set)
51 self._createInverse = createInverse
52 if self._createInverse:
53 self._itemsDict: defaultdict[TaskDef, set[DatasetTypeName]] = defaultdict(set)
55 def addProducer(self, key: DatasetTypeName, value: TaskDef) -> None:
56 """Add a key which is produced by some value.
58 Parameters
59 ----------
60 key : `~typing.TypeVar`
61 The type to track.
62 value : `~typing.TypeVar`
63 The type associated with the production of the key.
65 Raises
66 ------
67 ValueError
68 Raised if key is already declared to be produced by another value.
69 """
70 if (existing := self._producers.get(key)) is not None and existing != value:
71 raise ValueError(f"Only one node is allowed to produce {key}, the current producer is {existing}")
72 self._producers[key] = value
73 if self._createInverse:
74 self._itemsDict[value].add(key)
76 def addConsumer(self, key: DatasetTypeName, value: TaskDef) -> None:
77 """Add a key which is consumed by some value.
79 Parameters
80 ----------
81 key : `~typing.TypeVar`
82 The type to track.
83 value : `~typing.TypeVar`
84 The type associated with the consumption of the key.
85 """
86 self._consumers[key].add(value)
87 if self._createInverse:
88 self._itemsDict[value].add(key)
90 def getConsumers(self, key: DatasetTypeName) -> set[TaskDef]:
91 """Return all values associated with the consumption of the supplied
92 key.
94 Parameters
95 ----------
96 key : `~typing.TypeVar`
97 The type which has been tracked in the `_DatasetTracker`.
98 """
99 return self._consumers.get(key, set())
101 def getProducer(self, key: DatasetTypeName) -> TaskDef | None:
102 """Return the value associated with the consumption of the supplied
103 key.
105 Parameters
106 ----------
107 key : `~typing.TypeVar`
108 The type which has been tracked in the `_DatasetTracker`.
109 """
110 # This tracker may have had all nodes associated with a key removed
111 # and if there are no refs (empty set) should return None
112 return producer if (producer := self._producers.get(key)) else None
114 def getAll(self, key: DatasetTypeName) -> set[TaskDef]:
115 """Return all consumers and the producer associated with the the
116 supplied key.
118 Parameters
119 ----------
120 key : `~typing.TypeVar`
121 The type which has been tracked in the `_DatasetTracker`.
122 """
123 return self.getConsumers(key).union(x for x in (self.getProducer(key),) if x is not None)
125 @property
126 def inverse(self) -> defaultdict[TaskDef, set[DatasetTypeName]] | None:
127 """Return the inverse mapping if class was instantiated to create an
128 inverse, else return None.
129 """
130 return self._itemsDict if self._createInverse else None
132 def makeNetworkXGraph(self) -> nx.DiGraph:
133 """Create a NetworkX graph out of all the contained keys, using the
134 relations of producer and consumers to create the edges.
136 Returns
137 -------
138 graph : `networkx.DiGraph`
139 The graph created out of the supplied keys and their relations.
140 """
141 graph = nx.DiGraph()
142 for entry in self._producers.keys() | self._consumers.keys():
143 producer = self.getProducer(entry)
144 consumers = self.getConsumers(entry)
145 # This block is for tasks that consume existing inputs
146 if producer is None and consumers:
147 for consumer in consumers:
148 graph.add_node(consumer)
149 # This block is for tasks that produce output that is not consumed
150 # in this graph
151 elif producer is not None and not consumers:
152 graph.add_node(producer)
153 # all other connections
154 else:
155 for consumer in consumers:
156 graph.add_edge(producer, consumer)
157 return graph
159 def keys(self) -> set[DatasetTypeName]:
160 """Return all tracked keys."""
161 return self._producers.keys() | self._consumers.keys()
163 def __contains__(self, key: DatasetTypeName) -> bool:
164 """Check if a key is in the `_DatasetTracker`.
166 Parameters
167 ----------
168 key : `~typing.TypeVar`
169 The key to check.
171 Returns
172 -------
173 contains : `bool`
174 Boolean of the presence of the supplied key.
175 """
176 return key in self._producers or key in self._consumers