Coverage for python/lsst/pipe/base/graph/quantumNode.py: 68%

62 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-11 03:31 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

# Names exported by ``from <this module> import *``.
__all__ = (
    "QuantumNode",
    "NodeId",
    "BuildId",
)

30 

31import uuid 

32from dataclasses import dataclass 

33from typing import Any, NewType 

34 

35import pydantic 

36from lsst.daf.butler import ( 

37 DatasetRef, 

38 DimensionRecordsAccumulator, 

39 DimensionUniverse, 

40 Quantum, 

41 SerializedQuantum, 

42) 

43 

44from ..pipeline import TaskDef 

45from ..pipeline_graph import PipelineGraph, TaskNode 

46 

# Distinct static type for the identifier of one graph build; at runtime a
# ``BuildId`` value is just the ``str`` it was created from.
BuildId = NewType("BuildId", str)

48 

49 

def _hashDsRef(ref: DatasetRef) -> int:
    """Hash a dataset reference from its dataset type and data ID only.

    Parameters
    ----------
    ref : `DatasetRef`
        Reference whose ``datasetType`` and ``dataId`` are combined.

    Returns
    -------
    hash : `int`
        Hash of the ``(datasetType, dataId)`` pair.
    """
    key = (ref.datasetType, ref.dataId)
    return hash(key)

52 

53 

@dataclass(frozen=True)
class NodeId:
    """Deprecated identifier of a node within one construction of a
    `QuantumGraph`.

    This class is only needed when unpickling objects from `QuantumGraph`
    save formats 1 and 2, and must be kept until those formats are considered
    unloadable.

    A `NodeId` identifies a node within an individual construction of a
    `QuantumGraph`. It stays constant through a pickle and through any
    `QuantumGraph` method that returns a new `QuantumGraph`. It will not be
    the same if a new graph is built containing the same information in a
    `QuantumNode`, even when built from exactly the same inputs.

    `NodeId` plays no role in deciding the equality or identity (hash) of a
    `QuantumNode`; it is mainly useful for debugging or for working with
    various subsets of the same graph.

    This interface is a convenience only, and no guarantees of long-term
    stability are made. New implementations might change the `NodeId`, or
    provide more or fewer guarantees.
    """

    # The unique position of the node within the graph, assigned at graph
    # creation.
    number: int

    # Unique identifier created at the time the originating graph was created.
    buildId: BuildId

84 

85 

@dataclass(frozen=True)
class QuantumNode:
    """Class representing a node in the quantum graph.

    The ``quantum`` attribute represents the data that is to be processed at
    this node.

    Notes
    -----
    Two nodes compare equal when their ``quantum`` and ``taskDef`` attributes
    are equal; ``nodeId`` is not considered. The hash is computed once, at
    construction, from the task label and the quantum.

    Instances can be pickled and copied. A frozen dataclass with hand-written
    ``__slots__`` cannot use the default protocols for this, because those
    restore slot values through ``setattr``, which the frozen ``__setattr__``
    rejects. Explicit ``__getstate__``/``__setstate__`` methods are provided
    instead.
    """

    quantum: Quantum
    """The unit of data that is to be processed by this graph node"""
    taskDef: TaskDef
    """Definition of the task that will process the `Quantum` associated with
    this node.
    """
    nodeId: uuid.UUID
    """The unique position of the node within the graph assigned at graph
    creation.
    """

    @property
    def task_node(self) -> TaskNode:
        """Return the node object that represents this task in a pipeline
        graph.

        A new single-task `PipelineGraph` is built on every access, so callers
        that need the node repeatedly should keep the returned object.
        """
        pipeline_graph = PipelineGraph()
        return pipeline_graph.add_task(
            self.taskDef.label,
            self.taskDef.taskClass,
            self.taskDef.config,
            connections=self.taskDef.connections,
        )

    __slots__ = ("quantum", "taskDef", "nodeId", "_precomputedHash")

    def __post_init__(self) -> None:
        # use setattr here to preserve the frozenness of the QuantumNode
        self._precomputedHash: int
        object.__setattr__(self, "_precomputedHash", self._compute_hash())

    def _compute_hash(self) -> int:
        """Return the hash of this node's task label and quantum."""
        return hash((self.taskDef.label, self.quantum))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, QuantumNode):
            return False
        if self.quantum != other.quantum:
            return False
        return self.taskDef == other.taskDef

    def __hash__(self) -> int:
        """For graphs it is useful to have a more robust hash than provided
        by the default quantum id based hashing
        """
        return self._precomputedHash

    def __getstate__(self) -> tuple[Quantum, TaskDef, uuid.UUID]:
        """Return the state used by `pickle` and `copy`.

        The cached hash is deliberately excluded: it is derived from ``str``
        hashes, which are salted per interpreter process, so a value carried
        over from another process would be wrong. `__setstate__` recomputes
        it instead.
        """
        return (self.quantum, self.taskDef, self.nodeId)

    def __setstate__(self, state: tuple[Quantum, TaskDef, uuid.UUID]) -> None:
        """Restore state produced by `__getstate__`.

        ``object.__setattr__`` is required because the frozen dataclass
        ``__setattr__`` rejects every assignment.
        """
        quantum, taskDef, nodeId = state
        object.__setattr__(self, "quantum", quantum)
        object.__setattr__(self, "taskDef", taskDef)
        object.__setattr__(self, "nodeId", nodeId)
        object.__setattr__(self, "_precomputedHash", self._compute_hash())

    def __repr__(self) -> str:
        """Make more human readable string representation."""
        return (
            f"{self.__class__.__name__}(quantum={self.quantum}, taskDef={self.taskDef}, nodeId={self.nodeId})"
        )

    def to_simple(self, accumulator: DimensionRecordsAccumulator | None = None) -> SerializedQuantumNode:
        """Convert this node to its serializable model.

        Parameters
        ----------
        accumulator : `DimensionRecordsAccumulator`, optional
            Passed through unchanged to `Quantum.to_simple`.

        Returns
        -------
        simple : `SerializedQuantumNode`
            Model holding the serialized quantum, the task label, and the node
            ID.
        """
        return SerializedQuantumNode(
            quantum=self.quantum.to_simple(accumulator=accumulator),
            taskLabel=self.taskDef.label,
            nodeId=self.nodeId,
        )

    @classmethod
    def from_simple(
        cls,
        simple: SerializedQuantumNode,
        taskDefMap: dict[str, TaskDef],
        universe: DimensionUniverse,
    ) -> QuantumNode:
        """Reconstruct a node from its serializable model.

        Parameters
        ----------
        simple : `SerializedQuantumNode`
            Serialized form of the node.
        taskDefMap : `dict` [`str`, `TaskDef`]
            Mapping from task label to task definition; must contain
            ``simple.taskLabel``.
        universe : `DimensionUniverse`
            Passed to `Quantum.from_simple` to rebuild the quantum.

        Returns
        -------
        node : `QuantumNode`
            The reconstructed node (an instance of ``cls``).

        Raises
        ------
        KeyError
            Raised if ``simple.taskLabel`` is not a key of ``taskDefMap``.
        """
        return cls(
            quantum=Quantum.from_simple(simple.quantum, universe),
            taskDef=taskDefMap[simple.taskLabel],
            nodeId=simple.nodeId,
        )

    def _replace_quantum(self, quantum: Quantum) -> None:
        """Replace Quantum instance in this node.

        Parameters
        ----------
        quantum : `Quantum`
            New Quantum instance for this node.

        Raises
        ------
        ValueError
            Raised if the hash of the new quantum is different from the hash of
            the existing quantum.

        Notes
        -----
        This class is immutable and hashable, so this method checks that the
        new quantum does not invalidate its current hash. This method is
        supposed to be used only by the `QuantumGraph` class as an
        implementation detail, so it is made "underscore-protected".
        """
        if hash(quantum) != hash(self.quantum):
            raise ValueError(
                f"Hash of the new quantum {quantum} does not match hash of existing quantum {self.quantum}"
            )
        object.__setattr__(self, "quantum", quantum)

190 

191 

# Names of every field of ``SerializedQuantumNode``; used by its ``direct``
# constructor when building a model without validation.
_fields_set = {"nodeId", "quantum", "taskLabel"}

193 

194 

class SerializedQuantumNode(pydantic.BaseModel):
    """Model representing a `QuantumNode` in serializable form."""

    quantum: SerializedQuantum
    taskLabel: str
    nodeId: uuid.UUID

    @classmethod
    def direct(cls, *, quantum: dict[str, Any], taskLabel: str, nodeId: str) -> SerializedQuantumNode:
        """Construct a `SerializedQuantumNode` without pydantic validation.

        Parameters
        ----------
        quantum : `dict` [`str`, `~typing.Any`]
            Keyword arguments forwarded to `SerializedQuantum.direct`.
        taskLabel : `str`
            Label of the task that processes the quantum.
        nodeId : `str`
            String form of the node UUID; parsed with `uuid.UUID`.

        Returns
        -------
        node : `SerializedQuantumNode`
            The constructed model.

        Raises
        ------
        ValueError
            Raised by `uuid.UUID` if ``nodeId`` is not a valid UUID string.
        """
        return cls.model_construct(
            # Pydantic's keyword is ``_fields_set`` (single underscore); a
            # double-underscore spelling would be name-mangled inside this
            # class body and silently end up in ``**values``. A copy is
            # passed because pydantic stores the set on the instance, and the
            # module-level set must not be shared between models.
            _fields_set=set(_fields_set),
            quantum=SerializedQuantum.direct(**quantum),
            taskLabel=taskLabel,
            nodeId=uuid.UUID(nodeId),
        )