Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22from collections import defaultdict 

23 

24__all__ = ("_DatasetTracker", "DatasetTypeName") 

25 

26from dataclasses import dataclass, field 

27import networkx as nx 

28from typing import (DefaultDict, Generic, Optional, Set, TypeVar, Generator, Tuple, NewType,) 

29 

30from lsst.daf.butler import DatasetRef 

31 

32from .quantumNode import QuantumNode 

33from ..pipeline import TaskDef 

34 

# NewTypes
# Distinct static type for strings that are used as dataset type names. A
# value made with ``DatasetTypeName(...)`` is still a plain `str` at runtime;
# the distinction only exists for type checkers.
DatasetTypeName = NewType("DatasetTypeName", str)

# Generic type parameters
# _T is the key type of `_DatasetTracker`; it is constrained to be either a
# DatasetTypeName or a DatasetRef.
_T = TypeVar("_T", DatasetTypeName, DatasetRef)
# _U is the type of the values stored per key (in `_DatasetTrackerElement` and
# `_DatasetTracker`); it is constrained to be either a TaskDef or a
# QuantumNode.
_U = TypeVar("_U", TaskDef, QuantumNode)

41 

42 

@dataclass
class _DatasetTrackerElement(Generic[_U]):
    """Record of the values associated with one tracked key.

    An element holds any number of input values and at most one output value,
    all of the generic type ``_U``.
    """

    # Values registered as inputs. ``default_factory`` gives every instance
    # its own fresh set instead of one set shared between instances.
    inputs: Set[_U] = field(default_factory=set)
    # The single value registered as the output, or None while unset.
    output: Optional[_U] = None

47 

48 

class _DatasetTracker(Generic[_T, _U]):
    """Associate keys with the values registered as their inputs and output.

    Each key (type ``_T``) maps to a `_DatasetTrackerElement` holding a set of
    input values and at most one output value (type ``_U``). The recorded
    relationships can be exported as a directed `networkx` graph.

    Notes
    -----
    Only `addInput` and `addOutput` create entries. The read accessors
    (`getInputs`, `getOutput`, `getAll`) never add a key to the tracker, so
    querying an unknown key does not change the result of `keys` or of
    `makeNetworkXGraph`.
    """

    def __init__(self) -> None:
        # A defaultdict lets the add* methods create the per-key element on
        # first use. Read accessors must go through ``.get`` so that they do
        # not trigger the default factory.
        self._container: DefaultDict[_T, _DatasetTrackerElement[_U]] = defaultdict(_DatasetTrackerElement)

    def addInput(self, key: _T, value: _U) -> None:
        """Register ``value`` as one of the inputs of ``key``.

        Parameters
        ----------
        key : ``_T``
            Key to associate the value with; created if not yet tracked.
        value : ``_U``
            Value to add to the key's set of inputs. Adding the same value
            again has no further effect, because inputs are kept in a set.
        """
        self._container[key].inputs.add(value)

    def addOutput(self, key: _T, value: _U) -> None:
        """Register ``value`` as the output of ``key``.

        Parameters
        ----------
        key : ``_T``
            Key to associate the value with; created if not yet tracked.
        value : ``_U``
            Value to record as the key's output.

        Raises
        ------
        ValueError
            Raised if an output has already been set for ``key``.
        """
        element = self._container[key]
        if element.output is not None:
            raise ValueError(f"Only one output for key {key} is allowed, "
                             f"the current output is set to {element.output}")
        element.output = value

    def getInputs(self, key: _T) -> Set[_U]:
        """Return the values registered as inputs of ``key``.

        Parameters
        ----------
        key : ``_T``
            Key to look up.

        Returns
        -------
        inputs : `set` of ``_U``
            For a tracked key this is the tracker's own set, not a copy. For
            an unknown key it is a new empty set, and the key is not added to
            the tracker.
        """
        # dict.get bypasses defaultdict's factory, so no phantom entry is made
        element = self._container.get(key)
        if element is None:
            return set()
        return element.inputs

    def getOutput(self, key: _T) -> Optional[_U]:
        """Return the value registered as the output of ``key``.

        Parameters
        ----------
        key : ``_T``
            Key to look up.

        Returns
        -------
        output : ``_U`` or `None`
            The recorded output, or `None` if the key has no output or is not
            tracked. An unknown key is not added to the tracker.
        """
        element = self._container.get(key)
        if element is None:
            return None
        return element.output

    def getAll(self, key: _T) -> Set[_U]:
        """Return every value associated with ``key``.

        Parameters
        ----------
        key : ``_T``
            Key to look up.

        Returns
        -------
        values : `set` of ``_U``
            A new set containing the key's inputs plus its output, if one is
            set. Empty for an unknown key, which is not added to the tracker.
        """
        element = self._container.get(key)
        if element is None:
            return set()
        # Copy, so callers can never mutate the tracker through the result
        values = set(element.inputs)
        if element.output is not None:
            values.add(element.output)
        return values

    def makeNetworkXGraph(self) -> nx.DiGraph:
        """Build a directed graph from the tracked relationships.

        Returns
        -------
        graph : `networkx.DiGraph`
            Graph whose nodes are the tracked values. For every key there is
            an edge from the key's output to each of the key's inputs. Values
            whose counterpart is missing are kept as nodes without that edge.
        """
        graph = nx.DiGraph()
        graph.add_edges_from(self._datasetDictToEdgeIterator())
        # None is only a placeholder for a missing output or missing inputs;
        # dropping the node also drops every edge that touched it
        if None in graph.nodes():
            graph.remove_node(None)
        return graph

    def _datasetDictToEdgeIterator(self) -> Generator[Tuple[Optional[_U], Optional[_U]], None, None]:
        """Helper function designed to be used in conjunction with
        `networkx.DiGraph.add_edges_from`. This walks the mapping of keys to
        `_DatasetTrackerElement` and yields successive ``(output, input)``
        pairs of elements that are to be considered connected by the graph.

        Yields
        ------
        edge : `tuple`
            ``(output, input)`` for each input of each tracked key. Either
            member may be `None` when the key has no output or no inputs.
        """
        for entry in self._container.values():
            # If there is no inputs and only outputs (likely in test cases or
            # building inits or something) use None as a Node, that will then
            # be removed later
            inputs = entry.inputs or (None,)
            for inpt in inputs:
                yield (entry.output, inpt)

    def keys(self) -> Set[_T]:
        """Return the keys currently tracked.

        Returns
        -------
        keys : `set` of ``_T``
            A new set holding every key that has had an input or output
            registered.
        """
        return set(self._container)