Coverage for python/lsst/pipe/base/graph/_implDetails.py: 15%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

132 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("_DatasetTracker", "DatasetTypeName", "_pruner") 

24 

25from collections import defaultdict 

26 

27from itertools import chain 

28import networkx as nx 

29from typing import (DefaultDict, Generic, Optional, Set, TypeVar, NewType, Dict, Iterable) 

30 

31from lsst.daf.butler import DatasetRef, NamedKeyDict, Quantum 

32from lsst.pipe.base.connections import AdjustQuantumHelper 

33 

34from .quantumNode import QuantumNode 

35from ..pipeline import TaskDef 

36from .._status import NoWorkFound 

37 

# NewTypes
# Distinct typing-time alias for strings that name a dataset type, so the
# type checker can tell them apart from arbitrary strings (no runtime cost).
DatasetTypeName = NewType("DatasetTypeName", str)

# Generic type parameters
# _T: the "key" side of the tracker (a dataset, by name or by ref).
_T = TypeVar("_T", DatasetTypeName, DatasetRef)
# _U: the "value" side of the tracker (the producer/consumer of a dataset).
_U = TypeVar("_U", TaskDef, QuantumNode)

44 

45 

class _DatasetTracker(Generic[_T, _U]):
    r"""This is a generic container for tracking keys which are produced or
    consumed by some value. In the context of a QuantumGraph, keys may be
    `~lsst.daf.butler.DatasetRef`\ s and the values would be Quanta that either
    produce or consume those `~lsst.daf.butler.DatasetRef`\ s.

    Parameters
    ----------
    createInverse : bool
        When adding a key associated with a producer or consumer, also create
        an inverse mapping that allows looking up all the keys associated with
        some value. Defaults to False.
    """
    def __init__(self, createInverse: bool = False):
        # Each key has at most one producer.
        self._producers: Dict[_T, _U] = {}
        # Each key may have any number of consumers.
        self._consumers: DefaultDict[_T, Set[_U]] = defaultdict(set)
        self._createInverse = createInverse
        if self._createInverse:
            # value -> all keys it produces or consumes; only maintained when
            # the inverse mapping was requested at construction time.
            self._itemsDict: DefaultDict[_U, Set[_T]] = defaultdict(set)

    def addProducer(self, key: _T, value: _U):
        """Add a key which is produced by some value.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the production of the key

        Raises
        ------
        ValueError
            Raised if key is already declared to be produced by another value
        """
        # Re-registering the same producer is a no-op; only a *different*
        # producer for the same key is an error.
        if (existing := self._producers.get(key)) is not None and existing != value:
            raise ValueError(f"Only one node is allowed to produce {key}, "
                             f"the current producer is {existing}")
        self._producers[key] = value
        if self._createInverse:
            self._itemsDict[value].add(key)

    def removeProducer(self, key: _T, value: _U):
        """Remove a value (e.g. QuantumNode or TaskDef) from being considered
        a producer of the corresponding key. It is not an error to remove a
        key that is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the production of the key
        """
        self._producers.pop(key, None)
        if self._createInverse:
            # Use .get so a value never seen before does not create an empty
            # entry in the inverse defaultdict.
            if (result := self._itemsDict.get(value)):
                result.discard(key)

    def addConsumer(self, key: _T, value: _U):
        """Add a key which is consumed by some value.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the consumption of the key
        """
        self._consumers[key].add(value)
        if self._createInverse:
            self._itemsDict[value].add(key)

    def removeConsumer(self, key: _T, value: _U):
        """Remove a value (e.g. QuantumNode or TaskDef) from being considered
        a consumer of the corresponding key. It is not an error to remove a
        key that is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the consumption of the key
        """
        # Single .get lookup (a redundant duplicate lookup was removed here);
        # .get avoids inserting an empty set into the defaultdict.
        if (result := self._consumers.get(key)):
            result.discard(value)
        if self._createInverse:
            if (result := self._itemsDict.get(value)):
                result.discard(key)

    def getConsumers(self, key: _T) -> set[_U]:
        """Return all values associated with the consumption of the supplied
        key.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker
        """
        return self._consumers.get(key, set())

    def getProducer(self, key: _T) -> Optional[_U]:
        """Return the value associated with the production of the supplied
        key, or None if the key has no producer.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker
        """
        # This tracker may have had all nodes associated with a key removed
        # and if there are no refs (empty set) should return None
        return producer if (producer := self._producers.get(key)) else None

    def getAll(self, key: _T) -> set[_U]:
        """Return all consumers and the producer associated with the
        supplied key.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker
        """
        return self.getConsumers(key).union(x for x in (self.getProducer(key),) if x is not None)

    @property
    def inverse(self) -> Optional[DefaultDict[_U, Set[_T]]]:
        """Return the inverse mapping if class was instantiated to create an
        inverse, else return None.
        """
        return self._itemsDict if self._createInverse else None

    def makeNetworkXGraph(self) -> nx.DiGraph:
        """Create a NetworkX graph out of all the contained keys, using the
        relations of producer and consumers to create the edges.

        Returns
        -------
        graph : networkx.DiGraph
            The graph created out of the supplied keys and their relations
        """
        graph = nx.DiGraph()
        for entry in self._producers.keys() | self._consumers.keys():
            producer = self.getProducer(entry)
            consumers = self.getConsumers(entry)
            # This block is for tasks that consume existing inputs
            if producer is None and consumers:
                for consumer in consumers:
                    graph.add_node(consumer)
            # This block is for tasks that produce output that is not consumed
            # in this graph
            elif producer is not None and not consumers:
                graph.add_node(producer)
            # all other connections
            else:
                for consumer in consumers:
                    graph.add_edge(producer, consumer)
        return graph

    def keys(self) -> Set[_T]:
        """Return all tracked keys.
        """
        return self._producers.keys() | self._consumers.keys()

    def remove(self, key: _T):
        """Remove a key and its corresponding value from the tracker, this is
        a no-op if the key is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            A key tracked by the DatasetTracker
        """
        self._producers.pop(key, None)
        self._consumers.pop(key, None)

    def __contains__(self, key: _T) -> bool:
        """Check if a key is in the _DatasetTracker

        Parameters
        ----------
        key : TypeVar
            The key to check

        Returns
        -------
        contains : bool
            Boolean of the presence of the supplied key
        """
        return key in self._producers or key in self._consumers

238 

239 

def _pruner(datasetRefDict: _DatasetTracker[DatasetRef, QuantumNode], refsToRemove: Iterable[DatasetRef], *,
            alreadyPruned: Optional[Set[QuantumNode]] = None):
    r"""Prune supplied dataset refs out of datasetRefDict container, recursing
    to additional nodes dependent on pruned refs. This function modifies
    datasetRefDict in-place.

    Parameters
    ----------
    datasetRefDict : `_DatasetTracker[DatasetRef, QuantumNode]`
        The dataset tracker that maps `DatasetRef`\ s to the Quantum Nodes
        that produce/consume that `DatasetRef`
    refsToRemove : `Iterable` of `DatasetRef`
        The `DatasetRef`\ s which should be pruned from the input dataset
        tracker
    alreadyPruned : `set` of `QuantumNode`
        A set of nodes which have been pruned from the dataset tracker

    Raises
    ------
    RuntimeError
        Raised if a ref to prune is not a regular input (nor a component of
        one) of a consuming node, e.g. it is an initInput ref.
    """
    if alreadyPruned is None:
        alreadyPruned = set()
    for ref in refsToRemove:
        # make a copy here, because this structure will be modified in
        # recursion, hitting a node more than once won't be much of an
        # issue, as we skip anything that has been processed
        nodes = set(datasetRefDict.getConsumers(ref))
        for node in nodes:
            # This node will never be associated with this ref
            datasetRefDict.removeConsumer(ref, node)
            if node in alreadyPruned:
                continue
            # find the connection corresponding to the input ref
            connectionRefs = node.quantum.inputs.get(ref.datasetType)
            if connectionRefs is None:
                # look to see if any inputs are component refs that match the
                # input ref to prune
                toRemove = None
                for other in ref.datasetType.makeAllComponentDatasetTypes():
                    componentRefs = node.quantum.inputs.get(other)
                    if componentRefs is not None:
                        # now search the component refs and see which one
                        # matches the ref to trim
                        for cr in componentRefs:
                            if cr.makeCompositeRef() == ref:
                                toRemove = cr
                                connectionRefs = componentRefs
                                break
                    if toRemove is not None:
                        # Stop searching once a match is found; the error
                        # below must only fire when *no* component matched
                        # (the original loop's else clause could raise even
                        # after a successful match because nothing broke out
                        # of the outer loop).
                        break
                if toRemove is None:
                    # Ref must be an initInput ref and we want to ignore those
                    raise RuntimeError(f"Cannot prune on non-Input dataset type {ref.datasetType.name}")
            else:
                toRemove = ref

            # Build a trial input mapping with the pruned ref removed.
            tmpRefs = set(connectionRefs).difference((toRemove,))
            tmpConnections = NamedKeyDict(node.quantum.inputs.items())
            tmpConnections[toRemove.datasetType] = list(tmpRefs)
            helper = AdjustQuantumHelper(inputs=tmpConnections, outputs=node.quantum.outputs)
            assert node.quantum.dataId is not None, ("assert to make the type checker happy, it should not "
                                                     "actually be possible to not have dataId set to None "
                                                     "at this point")

            # Try to adjust the quantum with the reduced refs to make sure the
            # node will still satisfy all its conditions.
            #
            # If it can't because NoWorkFound is raised, that means a
            # connection is no longer present, and the node should be removed
            # from the graph.
            try:
                helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId)
                newQuantum = Quantum(taskName=node.quantum.taskName, taskClass=node.quantum.taskClass,
                                     dataId=node.quantum.dataId, initInputs=node.quantum.initInputs,
                                     inputs=helper.inputs, outputs=helper.outputs)
                # If the inputs or outputs were adjusted to something different
                # than what was supplied by the graph builder, disassociate
                # node from those refs, and if they are output refs, prune them
                # from downstream tasks. This means that based on new inputs
                # the task wants to produce fewer outputs, or consume fewer
                # inputs.
                for condition, existingMapping, newMapping, remover in (
                        (helper.inputs_adjusted, node.quantum.inputs, helper.inputs,
                         datasetRefDict.removeConsumer),
                        (helper.outputs_adjusted, node.quantum.outputs, helper.outputs,
                         datasetRefDict.removeProducer)):
                    if condition:
                        notNeeded = set()
                        for key in existingMapping:
                            if key not in newMapping:
                                # Whole dataset type dropped; track component
                                # refs by their composite, as the tracker does.
                                notNeeded |= {r if not r.isComponent() else r.makeCompositeRef()
                                              for r in existingMapping[key]}
                                continue
                            notNeeded |= set(existingMapping[key]) - set(newMapping[key])
                        if notNeeded:
                            # Use a loop variable distinct from the outer
                            # ``ref``: the original shadowed it, corrupting
                            # the remaining iterations over ``nodes``.
                            for unneededRef in notNeeded:
                                if unneededRef.isComponent():
                                    unneededRef = unneededRef.makeCompositeRef()
                                remover(unneededRef, node)
                            if remover is datasetRefDict.removeProducer:
                                _pruner(datasetRefDict, notNeeded, alreadyPruned=alreadyPruned)
                object.__setattr__(node, 'quantum', newQuantum)
                noWorkFound = False

            except NoWorkFound:
                noWorkFound = True

            if noWorkFound:
                # Disassociate this node from all of its remaining inputs
                # (including initInputs) and record that it has been pruned.
                for tmpRef in chain(chain.from_iterable(node.quantum.inputs.values()),
                                    node.quantum.initInputs.values()):
                    if tmpRef.isComponent():
                        tmpRef = tmpRef.makeCompositeRef()
                    datasetRefDict.removeConsumer(tmpRef, node)
                alreadyPruned.add(node)
                # prune all outputs produced by this node
                # mark that none of these will be produced
                forwardPrunes = set()
                for forwardRef in chain.from_iterable(node.quantum.outputs.values()):
                    datasetRefDict.removeProducer(forwardRef, node)
                    forwardPrunes.add(forwardRef)
                _pruner(datasetRefDict, forwardPrunes, alreadyPruned=alreadyPruned)

358 _pruner(datasetRefDict, forwardPrunes, alreadyPruned=alreadyPruned)