Coverage for python/lsst/pipe/base/graph/_implDetails.py: 15%
131 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-29 10:30 +0000
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-29 10:30 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("_DatasetTracker", "DatasetTypeName", "_pruner")
25from collections import defaultdict
26from itertools import chain
27from typing import DefaultDict, Dict, Generic, Iterable, List, NewType, Optional, Set, TypeVar
29import networkx as nx
30from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum
31from lsst.pipe.base.connections import AdjustQuantumHelper
33from .._status import NoWorkFound
34from ..pipeline import TaskDef
35from .quantumNode import QuantumNode
# NewTypes
# Distinct string subtype so dataset-type names cannot be confused with
# arbitrary strings by the type checker.
DatasetTypeName = NewType("DatasetTypeName", str)

# Generic type parameters
# _T is the "key" type tracked by _DatasetTracker: either a dataset type
# name or a concrete DatasetRef.
_T = TypeVar("_T", DatasetTypeName, DatasetRef)
# _U is the "value" type that produces/consumes keys: either a TaskDef
# (pipeline-level graph) or a QuantumNode (quantum-level graph).
_U = TypeVar("_U", TaskDef, QuantumNode)
45class _DatasetTracker(Generic[_T, _U]):
46 r"""This is a generic container for tracking keys which are produced or
47 consumed by some value. In the context of a QuantumGraph, keys may be
48 `~lsst.daf.butler.DatasetRef`\ s and the values would be Quanta that either
49 produce or consume those `~lsst.daf.butler.DatasetRef`\ s.
51 Prameters
52 ---------
53 createInverse : bool
54 When adding a key associated with a producer or consumer, also create
55 and inverse mapping that allows looking up all the keys associated with
56 some value. Defaults to False.
57 """
59 def __init__(self, createInverse: bool = False):
60 self._producers: Dict[_T, _U] = {}
61 self._consumers: DefaultDict[_T, Set[_U]] = defaultdict(set)
62 self._createInverse = createInverse
63 if self._createInverse:
64 self._itemsDict: DefaultDict[_U, Set[_T]] = defaultdict(set)
66 def addProducer(self, key: _T, value: _U) -> None:
67 """Add a key which is produced by some value.
69 Parameters
70 ----------
71 key : TypeVar
72 The type to track
73 value : TypeVar
74 The type associated with the production of the key
76 Raises
77 ------
78 ValueError
79 Raised if key is already declared to be produced by another value
80 """
81 if (existing := self._producers.get(key)) is not None and existing != value:
82 raise ValueError(f"Only one node is allowed to produce {key}, the current producer is {existing}")
83 self._producers[key] = value
84 if self._createInverse:
85 self._itemsDict[value].add(key)
87 def removeProducer(self, key: _T, value: _U) -> None:
88 """Remove a value (e.g. QuantumNode or TaskDef) from being considered
89 a producer of the corresponding key. It is not an error to remove a
90 key that is not in the tracker.
92 Parameters
93 ----------
94 key : TypeVar
95 The type to track
96 value : TypeVar
97 The type associated with the production of the key
98 """
99 self._producers.pop(key, None)
100 if self._createInverse:
101 if result := self._itemsDict.get(value):
102 result.discard(key)
104 def addConsumer(self, key: _T, value: _U) -> None:
105 """Add a key which is consumed by some value.
107 Parameters
108 ----------
109 key : TypeVar
110 The type to track
111 value : TypeVar
112 The type associated with the consumption of the key
113 """
114 self._consumers[key].add(value)
115 if self._createInverse:
116 self._itemsDict[value].add(key)
118 def removeConsumer(self, key: _T, value: _U) -> None:
119 """Remove a value (e.g. QuantumNode or TaskDef) from being considered
120 a consumer of the corresponding key. It is not an error to remove a
121 key that is not in the tracker.
123 Parameters
124 ----------
125 key : TypeVar
126 The type to track
127 value : TypeVar
128 The type associated with the consumption of the key
129 """
130 if (result := self._consumers.get(key)) is not None:
131 result.discard(value)
132 if self._createInverse:
133 if result := self._itemsDict.get(value):
134 result.discard(key)
136 def getConsumers(self, key: _T) -> Set[_U]:
137 """Return all values associated with the consumption of the supplied
138 key.
140 Parameters
141 ----------
142 key : TypeVar
143 The type which has been tracked in the _DatasetTracker
144 """
145 return self._consumers.get(key, set())
147 def getProducer(self, key: _T) -> Optional[_U]:
148 """Return the value associated with the consumption of the supplied
149 key.
151 Parameters
152 ----------
153 key : TypeVar
154 The type which has been tracked in the _DatasetTracker
155 """
156 # This tracker may have had all nodes associated with a key removed
157 # and if there are no refs (empty set) should return None
158 return producer if (producer := self._producers.get(key)) else None
160 def getAll(self, key: _T) -> set[_U]:
161 """Return all consumers and the producer associated with the the
162 supplied key.
164 Parameters
165 ----------
166 key : TypeVar
167 The type which has been tracked in the _DatasetTracker
168 """
170 return self.getConsumers(key).union(x for x in (self.getProducer(key),) if x is not None)
172 @property
173 def inverse(self) -> Optional[DefaultDict[_U, Set[_T]]]:
174 """Return the inverse mapping if class was instantiated to create an
175 inverse, else return None.
176 """
177 return self._itemsDict if self._createInverse else None
179 def makeNetworkXGraph(self) -> nx.DiGraph:
180 """Create a NetworkX graph out of all the contained keys, using the
181 relations of producer and consumers to create the edges.
183 Returns:
184 graph : networkx.DiGraph
185 The graph created out of the supplied keys and their relations
186 """
187 graph = nx.DiGraph()
188 for entry in self._producers.keys() | self._consumers.keys():
189 producer = self.getProducer(entry)
190 consumers = self.getConsumers(entry)
191 # This block is for tasks that consume existing inputs
192 if producer is None and consumers:
193 for consumer in consumers:
194 graph.add_node(consumer)
195 # This block is for tasks that produce output that is not consumed
196 # in this graph
197 elif producer is not None and not consumers:
198 graph.add_node(producer)
199 # all other connections
200 else:
201 for consumer in consumers:
202 graph.add_edge(producer, consumer)
203 return graph
205 def keys(self) -> Set[_T]:
206 """Return all tracked keys."""
207 return self._producers.keys() | self._consumers.keys()
209 def remove(self, key: _T) -> None:
210 """Remove a key and its corresponding value from the tracker, this is
211 a no-op if the key is not in the tracker.
213 Parameters
214 ----------
215 key : TypeVar
216 A key tracked by the DatasetTracker
217 """
218 self._producers.pop(key, None)
219 self._consumers.pop(key, None)
221 def __contains__(self, key: _T) -> bool:
222 """Check if a key is in the _DatasetTracker
224 Parameters
225 ----------
226 key : TypeVar
227 The key to check
229 Returns
230 -------
231 contains : bool
232 Boolean of the presence of the supplied key
233 """
234 return key in self._producers or key in self._consumers
def _pruner(
    datasetRefDict: _DatasetTracker[DatasetRef, QuantumNode],
    refsToRemove: Iterable[DatasetRef],
    *,
    alreadyPruned: Optional[Set[QuantumNode]] = None,
) -> None:
    r"""Prune supplied dataset refs out of datasetRefDict container, recursing
    to additional nodes dependent on pruned refs. This function modifies
    datasetRefDict in-place.

    Parameters
    ----------
    datasetRefDict : `_DatasetTracker[DatasetRef, QuantumNode]`
        The dataset tracker that maps `DatasetRef`\ s to the Quantum Nodes
        that produce/consume that `DatasetRef`
    refsToRemove : `Iterable` of `DatasetRef`
        The `DatasetRef`\ s which should be pruned from the input dataset
        tracker
    alreadyPruned : `set` of `QuantumNode`
        A set of nodes which have been pruned from the dataset tracker
    """
    if alreadyPruned is None:
        alreadyPruned = set()
    for ref in refsToRemove:
        # make a copy here, because this structure will be modified in
        # recursion, hitting a node more than once won't be much of an
        # issue, as we skip anything that has been processed
        nodes = set(datasetRefDict.getConsumers(ref))
        for node in nodes:
            # This node will never be associated with this ref
            datasetRefDict.removeConsumer(ref, node)
            if node in alreadyPruned:
                continue
            # find the connection corresponding to the input ref
            connectionRefs = node.quantum.inputs.get(ref.datasetType)
            if connectionRefs is None:
                # look to see if any inputs are component refs that match the
                # input ref to prune
                others = ref.datasetType.makeAllComponentDatasetTypes()
                # for each other component type check if there are associated
                # refs
                for other in others:
                    connectionRefs = node.quantum.inputs.get(other)
                    if connectionRefs is not None:
                        # now search the component refs and see which one
                        # matches the ref to trim
                        for cr in connectionRefs:
                            if cr.makeCompositeRef() == ref:
                                toRemove = cr
                        # stop scanning component types once one matched
                        break
                else:
                    # Ref must be an initInput ref and we want to ignore those
                    raise RuntimeError(f"Cannot prune on non-Input dataset type {ref.datasetType.name}")
            else:
                toRemove = ref

            # Build a trial input mapping with the pruned ref removed.
            tmpRefs = set(connectionRefs).difference((toRemove,))
            tmpConnections = NamedKeyDict[DatasetType, List[DatasetRef]](node.quantum.inputs.items())
            tmpConnections[toRemove.datasetType] = list(tmpRefs)
            helper = AdjustQuantumHelper(inputs=tmpConnections, outputs=node.quantum.outputs)
            assert node.quantum.dataId is not None, (
                "assert to make the type checker happy, it should not "
                "actually be possible to not have dataId set to None "
                "at this point"
            )

            # Try to adjust the quantum with the reduced refs to make sure the
            # node will still satisfy all its conditions.
            #
            # If it can't because NoWorkFound is raised, that means a
            # connection is no longer present, and the node should be removed
            # from the graph.
            try:
                helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId)
                newQuantum = Quantum(
                    taskName=node.quantum.taskName,
                    taskClass=node.quantum.taskClass,
                    dataId=node.quantum.dataId,
                    initInputs=node.quantum.initInputs,
                    inputs=helper.inputs,
                    outputs=helper.outputs,
                )
                # If the inputs or outputs were adjusted to something different
                # than what was supplied by the graph builder, disassociate
                # node from those refs, and if they are output refs, prune them
                # from downstream tasks. This means that based on new inputs
                # the task wants to produce fewer outputs, or consume fewer
                # inputs.
                for condition, existingMapping, newMapping, remover in (
                    (
                        helper.inputs_adjusted,
                        node.quantum.inputs,
                        helper.inputs,
                        datasetRefDict.removeConsumer,
                    ),
                    (
                        helper.outputs_adjusted,
                        node.quantum.outputs,
                        helper.outputs,
                        datasetRefDict.removeProducer,
                    ),
                ):
                    if condition:
                        notNeeded = set()
                        for key in existingMapping:
                            if key not in newMapping:
                                # Whole dataset type dropped: collect all its
                                # refs, promoting components to composites.
                                compositeRefs = (
                                    r if not r.isComponent() else r.makeCompositeRef()
                                    for r in existingMapping[key]
                                )
                                notNeeded |= set(compositeRefs)
                                continue
                            notNeeded |= set(existingMapping[key]) - set(newMapping[key])
                        if notNeeded:
                            # NOTE: this loop variable must NOT be named
                            # ``ref``; doing so rebinds the outer
                            # ``for ref in refsToRemove`` variable and makes
                            # subsequent node iterations prune the wrong ref.
                            for unneededRef in notNeeded:
                                if unneededRef.isComponent():
                                    unneededRef = unneededRef.makeCompositeRef()
                                remover(unneededRef, node)
                            if remover is datasetRefDict.removeProducer:
                                # Outputs that will no longer be produced must
                                # be pruned from all downstream consumers.
                                _pruner(datasetRefDict, notNeeded, alreadyPruned=alreadyPruned)
                object.__setattr__(node, "quantum", newQuantum)
                noWorkFound = False

            except NoWorkFound:
                noWorkFound = True

            if noWorkFound:
                # This will throw if the length is less than the minimum number
                for tmpRef in chain(
                    chain.from_iterable(node.quantum.inputs.values()), node.quantum.initInputs.values()
                ):
                    if tmpRef.isComponent():
                        tmpRef = tmpRef.makeCompositeRef()
                    datasetRefDict.removeConsumer(tmpRef, node)
                alreadyPruned.add(node)
                # prune all outputs produced by this node
                # mark that none of these will be produced
                forwardPrunes = set()
                for forwardRef in chain.from_iterable(node.quantum.outputs.values()):
                    datasetRefDict.removeProducer(forwardRef, node)
                    forwardPrunes.add(forwardRef)
                _pruner(datasetRefDict, forwardPrunes, alreadyPruned=alreadyPruned)