Coverage for python/lsst/pipe/base/graph/_implDetails.py: 15%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

132 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("_DatasetTracker", "DatasetTypeName", "_pruner") 

24 

25from collections import defaultdict 

26from itertools import chain 

27from typing import DefaultDict, Dict, Generic, Iterable, NewType, Optional, Set, TypeVar 

28 

29import networkx as nx 

30from lsst.daf.butler import DatasetRef, NamedKeyDict, Quantum 

31from lsst.pipe.base.connections import AdjustQuantumHelper 

32 

33from .._status import NoWorkFound 

34from ..pipeline import TaskDef 

35from .quantumNode import QuantumNode 

36 

# NewTypes
# Distinct type for plain strings that name a dataset type, so the type
# checker can tell them apart from arbitrary strings.
DatasetTypeName = NewType("DatasetTypeName", str)

# Generic type parameters
# _T: the "key" type tracked by _DatasetTracker (a dataset type name or a
# concrete DatasetRef).
_T = TypeVar("_T", DatasetTypeName, DatasetRef)
# _U: the "value" type that produces/consumes keys (a TaskDef when keys are
# dataset type names, or a QuantumNode when keys are DatasetRefs).
_U = TypeVar("_U", TaskDef, QuantumNode)

43 

44 

class _DatasetTracker(Generic[_T, _U]):
    r"""A generic container for tracking keys which are produced or
    consumed by some value. In the context of a QuantumGraph, keys may be
    `~lsst.daf.butler.DatasetRef`\ s and the values would be Quanta that either
    produce or consume those `~lsst.daf.butler.DatasetRef`\ s.

    Parameters
    ----------
    createInverse : bool
        When adding a key associated with a producer or consumer, also create
        an inverse mapping that allows looking up all the keys associated with
        some value. Defaults to False.
    """

    def __init__(self, createInverse: bool = False):
        # Each key may have at most one producer.
        self._producers: Dict[_T, _U] = {}
        # Each key may have any number of consumers.
        self._consumers: DefaultDict[_T, Set[_U]] = defaultdict(set)
        self._createInverse = createInverse
        if self._createInverse:
            # Inverse mapping: value -> all keys it produces or consumes.
            # Only maintained when requested, to save memory.
            self._itemsDict: DefaultDict[_U, Set[_T]] = defaultdict(set)

    def addProducer(self, key: _T, value: _U):
        """Add a key which is produced by some value.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the production of the key

        Raises
        ------
        ValueError
            Raised if key is already declared to be produced by another value
        """
        # Re-adding the same producer is a no-op; only a *different* producer
        # for the same key is an error.
        if (existing := self._producers.get(key)) is not None and existing != value:
            raise ValueError(f"Only one node is allowed to produce {key}, the current producer is {existing}")
        self._producers[key] = value
        if self._createInverse:
            self._itemsDict[value].add(key)

    def removeProducer(self, key: _T, value: _U):
        """Remove a value (e.g. QuantumNode or TaskDef) from being considered
        a producer of the corresponding key. It is not an error to remove a
        key that is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the production of the key
        """
        self._producers.pop(key, None)
        if self._createInverse:
            if result := self._itemsDict.get(value):
                result.discard(key)

    def addConsumer(self, key: _T, value: _U):
        """Add a key which is consumed by some value.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the consumption of the key
        """
        self._consumers[key].add(value)
        if self._createInverse:
            self._itemsDict[value].add(key)

    def removeConsumer(self, key: _T, value: _U):
        """Remove a value (e.g. QuantumNode or TaskDef) from being considered
        a consumer of the corresponding key. It is not an error to remove a
        key that is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the consumption of the key
        """
        # Note: the original code performed this lookup twice; a single
        # walrus-bound lookup is sufficient.
        if result := self._consumers.get(key):
            result.discard(value)
        if self._createInverse:
            if result := self._itemsDict.get(value):
                result.discard(key)

    def getConsumers(self, key: _T) -> set[_U]:
        """Return all values associated with the consumption of the supplied
        key.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker
        """
        # Returns the live internal set when the key is present; callers that
        # mutate the tracker while iterating should copy it first.
        return self._consumers.get(key, set())

    def getProducer(self, key: _T) -> Optional[_U]:
        """Return the value associated with the production of the supplied
        key, or None if the key has no producer.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker
        """
        # This tracker may have had all nodes associated with a key removed
        # and if there are no refs (empty set) should return None
        return producer if (producer := self._producers.get(key)) else None

    def getAll(self, key: _T) -> set[_U]:
        """Return all consumers and the producer associated with the
        supplied key.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker
        """
        return self.getConsumers(key).union(x for x in (self.getProducer(key),) if x is not None)

    @property
    def inverse(self) -> Optional[DefaultDict[_U, Set[_T]]]:
        """Return the inverse mapping if class was instantiated to create an
        inverse, else return None.
        """
        return self._itemsDict if self._createInverse else None

    def makeNetworkXGraph(self) -> nx.DiGraph:
        """Create a NetworkX graph out of all the contained keys, using the
        relations of producer and consumers to create the edges.

        Returns
        -------
        graph : networkx.DiGraph
            The graph created out of the supplied keys and their relations
        """
        graph = nx.DiGraph()
        for entry in self._producers.keys() | self._consumers.keys():
            producer = self.getProducer(entry)
            consumers = self.getConsumers(entry)
            # This block is for tasks that consume existing inputs
            if producer is None and consumers:
                for consumer in consumers:
                    graph.add_node(consumer)
            # This block is for tasks that produce output that is not consumed
            # in this graph
            elif producer is not None and not consumers:
                graph.add_node(producer)
            # all other connections
            else:
                for consumer in consumers:
                    graph.add_edge(producer, consumer)
        return graph

    def keys(self) -> Set[_T]:
        """Return all tracked keys."""
        return self._producers.keys() | self._consumers.keys()

    def remove(self, key: _T):
        """Remove a key and its corresponding value from the tracker, this is
        a no-op if the key is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            A key tracked by the DatasetTracker
        """
        self._producers.pop(key, None)
        self._consumers.pop(key, None)

    def __contains__(self, key: _T) -> bool:
        """Check if a key is in the _DatasetTracker

        Parameters
        ----------
        key : TypeVar
            The key to check

        Returns
        -------
        contains : bool
            Boolean of the presence of the supplied key
        """
        return key in self._producers or key in self._consumers

236 

237 

238def _pruner( 

239 datasetRefDict: _DatasetTracker[DatasetRef, QuantumNode], 

240 refsToRemove: Iterable[DatasetRef], 

241 *, 

242 alreadyPruned: Optional[Set[QuantumNode]] = None, 

243): 

244 r"""Prune supplied dataset refs out of datasetRefDict container, recursing 

245 to additional nodes dependant on pruned refs. This function modifies 

246 datasetRefDict in-place. 

247 

248 Parameters 

249 ---------- 

250 datasetRefDict : `_DatasetTracker[DatasetRef, QuantumNode]` 

251 The dataset tracker that maps `DatasetRef`\ s to the Quantum Nodes 

252 that produce/consume that `DatasetRef` 

253 refsToRemove : `Iterable` of `DatasetRef` 

254 The `DatasetRef`\ s which should be pruned from the input dataset 

255 tracker 

256 alreadyPruned : `set` of `QuantumNode` 

257 A set of nodes which have been pruned from the dataset tracker 

258 """ 

259 if alreadyPruned is None: 

260 alreadyPruned = set() 

261 for ref in refsToRemove: 

262 # make a copy here, because this structure will be modified in 

263 # recursion, hitting a node more than once won't be much of an 

264 # issue, as we skip anything that has been processed 

265 nodes = set(datasetRefDict.getConsumers(ref)) 

266 for node in nodes: 

267 # This node will never be associated with this ref 

268 datasetRefDict.removeConsumer(ref, node) 

269 if node in alreadyPruned: 

270 continue 

271 # find the connection corresponding to the input ref 

272 connectionRefs = node.quantum.inputs.get(ref.datasetType) 

273 if connectionRefs is None: 

274 # look to see if any inputs are component refs that match the 

275 # input ref to prune 

276 others = ref.datasetType.makeAllComponentDatasetTypes() 

277 # for each other component type check if there are assocated 

278 # refs 

279 for other in others: 

280 connectionRefs = node.quantum.inputs.get(other) 

281 if connectionRefs is not None: 

282 # now search the component refs and see which one 

283 # matches the ref to trim 

284 for cr in connectionRefs: 

285 if cr.makeCompositeRef() == ref: 

286 toRemove = cr 

287 break 

288 else: 

289 # Ref must be an initInput ref and we want to ignore those 

290 raise RuntimeError(f"Cannot prune on non-Input dataset type {ref.datasetType.name}") 

291 else: 

292 toRemove = ref 

293 

294 tmpRefs = set(connectionRefs).difference((toRemove,)) 

295 tmpConnections = NamedKeyDict(node.quantum.inputs.items()) 

296 tmpConnections[toRemove.datasetType] = list(tmpRefs) 

297 helper = AdjustQuantumHelper(inputs=tmpConnections, outputs=node.quantum.outputs) 

298 assert node.quantum.dataId is not None, ( 

299 "assert to make the type checker happy, it should not " 

300 "actually be possible to not have dataId set to None " 

301 "at this point" 

302 ) 

303 

304 # Try to adjust the quantum with the reduced refs to make sure the 

305 # node will still satisfy all its conditions. 

306 # 

307 # If it can't because NoWorkFound is raised, that means a 

308 # connection is no longer present, and the node should be removed 

309 # from the graph. 

310 try: 

311 helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId) 

312 newQuantum = Quantum( 

313 taskName=node.quantum.taskName, 

314 taskClass=node.quantum.taskClass, 

315 dataId=node.quantum.dataId, 

316 initInputs=node.quantum.initInputs, 

317 inputs=helper.inputs, 

318 outputs=helper.outputs, 

319 ) 

320 # If the inputs or outputs were adjusted to something different 

321 # than what was supplied by the graph builder, dissassociate 

322 # node from those refs, and if they are output refs, prune them 

323 # from downstream tasks. This means that based on new inputs 

324 # the task wants to produce fewer outputs, or consume fewer 

325 # inputs. 

326 for condition, existingMapping, newMapping, remover in ( 

327 ( 

328 helper.inputs_adjusted, 

329 node.quantum.inputs, 

330 helper.inputs, 

331 datasetRefDict.removeConsumer, 

332 ), 

333 ( 

334 helper.outputs_adjusted, 

335 node.quantum.outputs, 

336 helper.outputs, 

337 datasetRefDict.removeProducer, 

338 ), 

339 ): 

340 if condition: 

341 notNeeded = set() 

342 for key in existingMapping: 

343 if key not in newMapping: 

344 compositeRefs = ( 

345 r if not r.isComponent() else r.makeCompositeRef() 

346 for r in existingMapping[key] 

347 ) 

348 notNeeded |= set(compositeRefs) 

349 continue 

350 notNeeded |= set(existingMapping[key]) - set(newMapping[key]) 

351 if notNeeded: 

352 for ref in notNeeded: 

353 if ref.isComponent(): 

354 ref = ref.makeCompositeRef() 

355 remover(ref, node) 

356 if remover is datasetRefDict.removeProducer: 

357 _pruner(datasetRefDict, notNeeded, alreadyPruned=alreadyPruned) 

358 object.__setattr__(node, "quantum", newQuantum) 

359 noWorkFound = False 

360 

361 except NoWorkFound: 

362 noWorkFound = True 

363 

364 if noWorkFound: 

365 # This will throw if the length is less than the minimum number 

366 for tmpRef in chain( 

367 chain.from_iterable(node.quantum.inputs.values()), node.quantum.initInputs.values() 

368 ): 

369 if tmpRef.isComponent(): 

370 tmpRef = tmpRef.makeCompositeRef() 

371 datasetRefDict.removeConsumer(tmpRef, node) 

372 alreadyPruned.add(node) 

373 # prune all outputs produced by this node 

374 # mark that none of these will be produced 

375 forwardPrunes = set() 

376 for forwardRef in chain.from_iterable(node.quantum.outputs.values()): 

377 datasetRefDict.removeProducer(forwardRef, node) 

378 forwardPrunes.add(forwardRef) 

379 _pruner(datasetRefDict, forwardPrunes, alreadyPruned=alreadyPruned)