Coverage for python/lsst/pipe/base/graph/_implDetails.py: 15%

131 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-20 08:52 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("_DatasetTracker", "DatasetTypeName", "_pruner") 

24 

25from collections import defaultdict 

26from itertools import chain 

27from typing import DefaultDict, Dict, Generic, Iterable, List, NewType, Optional, Set, TypeVar 

28 

29import networkx as nx 

30from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum 

31from lsst.pipe.base.connections import AdjustQuantumHelper 

32 

33from .._status import NoWorkFound 

34from ..pipeline import TaskDef 

35from .quantumNode import QuantumNode 

36 

# NewTypes
# A distinct alias of `str` used for dataset type names, so that a plain
# string cannot be passed where a dataset type name is expected (static
# typing only; at runtime this is just `str`).
DatasetTypeName = NewType("DatasetTypeName", str)

# Generic type parameters
# _T is the "key" type tracked by _DatasetTracker: either a dataset type
# name or a concrete `DatasetRef`.
_T = TypeVar("_T", DatasetTypeName, DatasetRef)
# _U is the "value" type that produces/consumes keys: either a `TaskDef`
# (pipeline-level graph) or a `QuantumNode` (quantum-level graph).
_U = TypeVar("_U", TaskDef, QuantumNode)

44 

class _DatasetTracker(Generic[_T, _U]):
    r"""This is a generic container for tracking keys which are produced or
    consumed by some value. In the context of a QuantumGraph, keys may be
    `~lsst.daf.butler.DatasetRef`\ s and the values would be Quanta that either
    produce or consume those `~lsst.daf.butler.DatasetRef`\ s.

    Parameters
    ----------
    createInverse : bool
        When adding a key associated with a producer or consumer, also create
        an inverse mapping that allows looking up all the keys associated with
        some value. Defaults to False.
    """

    def __init__(self, createInverse: bool = False):
        # Each key has at most one producer.
        self._producers: Dict[_T, _U] = {}
        # Each key may have any number of consumers.
        self._consumers: DefaultDict[_T, Set[_U]] = defaultdict(set)
        self._createInverse = createInverse
        if self._createInverse:
            # Inverse mapping from a value to every key it produces or
            # consumes; only maintained when requested at construction.
            self._itemsDict: DefaultDict[_U, Set[_T]] = defaultdict(set)

    def addProducer(self, key: _T, value: _U) -> None:
        """Add a key which is produced by some value.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the production of the key

        Raises
        ------
        ValueError
            Raised if key is already declared to be produced by another value
        """
        # Re-registering the same producer is a no-op; only a *different*
        # producer for an already-produced key is an error.
        if (existing := self._producers.get(key)) is not None and existing != value:
            raise ValueError(f"Only one node is allowed to produce {key}, the current producer is {existing}")
        self._producers[key] = value
        if self._createInverse:
            self._itemsDict[value].add(key)

    def removeProducer(self, key: _T, value: _U) -> None:
        """Remove a value (e.g. QuantumNode or TaskDef) from being considered
        a producer of the corresponding key. It is not an error to remove a
        key that is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the production of the key
        """
        self._producers.pop(key, None)
        if self._createInverse:
            # Keep the inverse mapping in sync if it exists for this value.
            if result := self._itemsDict.get(value):
                result.discard(key)

    def addConsumer(self, key: _T, value: _U) -> None:
        """Add a key which is consumed by some value.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the consumption of the key
        """
        self._consumers[key].add(value)
        if self._createInverse:
            self._itemsDict[value].add(key)

    def removeConsumer(self, key: _T, value: _U) -> None:
        """Remove a value (e.g. QuantumNode or TaskDef) from being considered
        a consumer of the corresponding key. It is not an error to remove a
        key that is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            The type to track
        value : TypeVar
            The type associated with the consumption of the key
        """
        if (result := self._consumers.get(key)) is not None:
            result.discard(value)
        if self._createInverse:
            # Keep the inverse mapping in sync if it exists for this value.
            if result := self._itemsDict.get(value):
                result.discard(key)

    def getConsumers(self, key: _T) -> Set[_U]:
        """Return all values associated with the consumption of the supplied
        key.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker

        Returns
        -------
        consumers : `Set` of TypeVar
            The values that consume the supplied key; an empty set when the
            key is untracked. NOTE: when the key is present this is the
            tracker's internal (live) set, so callers should copy it before
            mutating the tracker while iterating.
        """
        return self._consumers.get(key, set())

    def getProducer(self, key: _T) -> Optional[_U]:
        """Return the value associated with the production of the supplied
        key.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker

        Returns
        -------
        producer : TypeVar or `None`
            The value that produces the supplied key, or `None` when there
            is no (truthy) producer recorded.
        """
        # This tracker may have had all nodes associated with a key removed;
        # falsy entries are treated as "no producer" and reported as None.
        return producer if (producer := self._producers.get(key)) else None

    def getAll(self, key: _T) -> Set[_U]:
        """Return all consumers and the producer associated with the
        supplied key.

        Parameters
        ----------
        key : TypeVar
            The type which has been tracked in the _DatasetTracker

        Returns
        -------
        values : `Set` of TypeVar
            The union of the key's consumers and, if present, its producer.
        """
        return self.getConsumers(key).union(x for x in (self.getProducer(key),) if x is not None)

    @property
    def inverse(self) -> Optional[DefaultDict[_U, Set[_T]]]:
        """Return the inverse mapping if class was instantiated to create an
        inverse, else return None.
        """
        return self._itemsDict if self._createInverse else None

    def makeNetworkXGraph(self) -> nx.DiGraph:
        """Create a NetworkX graph out of all the contained keys, using the
        relations of producer and consumers to create the edges.

        Returns
        -------
        graph : `networkx.DiGraph`
            The graph created out of the supplied keys and their relations
        """
        graph = nx.DiGraph()
        for entry in self._producers.keys() | self._consumers.keys():
            producer = self.getProducer(entry)
            consumers = self.getConsumers(entry)
            # This block is for tasks that consume existing inputs
            if producer is None and consumers:
                for consumer in consumers:
                    graph.add_node(consumer)
            # This block is for tasks that produce output that is not consumed
            # in this graph
            elif producer is not None and not consumers:
                graph.add_node(producer)
            # all other connections
            else:
                for consumer in consumers:
                    graph.add_edge(producer, consumer)
        return graph

    def keys(self) -> Set[_T]:
        """Return all tracked keys.

        Returns
        -------
        keys : `Set` of TypeVar
            All keys known to the tracker, whether produced or consumed.
        """
        return self._producers.keys() | self._consumers.keys()

    def remove(self, key: _T) -> None:
        """Remove a key and its corresponding value from the tracker, this is
        a no-op if the key is not in the tracker.

        Parameters
        ----------
        key : TypeVar
            A key tracked by the DatasetTracker
        """
        self._producers.pop(key, None)
        self._consumers.pop(key, None)

    def __contains__(self, key: _T) -> bool:
        """Check if a key is in the _DatasetTracker

        Parameters
        ----------
        key : TypeVar
            The key to check

        Returns
        -------
        contains : bool
            Boolean of the presence of the supplied key
        """
        return key in self._producers or key in self._consumers

235 

236 

def _pruner(
    datasetRefDict: _DatasetTracker[DatasetRef, QuantumNode],
    refsToRemove: Iterable[DatasetRef],
    *,
    alreadyPruned: Optional[Set[QuantumNode]] = None,
) -> None:
    r"""Prune supplied dataset refs out of datasetRefDict container, recursing
    to additional nodes dependent on pruned refs. This function modifies
    datasetRefDict in-place.

    Parameters
    ----------
    datasetRefDict : `_DatasetTracker[DatasetRef, QuantumNode]`
        The dataset tracker that maps `DatasetRef`\ s to the Quantum Nodes
        that produce/consume that `DatasetRef`
    refsToRemove : `Iterable` of `DatasetRef`
        The `DatasetRef`\ s which should be pruned from the input dataset
        tracker
    alreadyPruned : `set` of `QuantumNode`
        A set of nodes which have been pruned from the dataset tracker.
        Shared across the recursive calls so a node is never processed twice.

    Raises
    ------
    RuntimeError
        Raised if a supplied ref is neither a regular nor a component input
        of a node recorded as consuming it (i.e. it appears to be an
        initInput, which cannot be pruned here).
    """
    if alreadyPruned is None:
        alreadyPruned = set()
    for ref in refsToRemove:
        # make a copy here, because this structure will be modified in
        # recursion, hitting a node more than once won't be much of an
        # issue, as we skip anything that has been processed
        nodes = set(datasetRefDict.getConsumers(ref))
        for node in nodes:
            # This node will never be associated with this ref
            datasetRefDict.removeConsumer(ref, node)
            if node in alreadyPruned:
                continue
            # find the connection corresponding to the input ref
            connectionRefs = node.quantum.inputs.get(ref.datasetType)
            if connectionRefs is None:
                # look to see if any inputs are component refs that match the
                # input ref to prune
                others = ref.datasetType.makeAllComponentDatasetTypes()
                # for each other component type check if there are associated
                # refs
                for other in others:
                    connectionRefs = node.quantum.inputs.get(other)
                    if connectionRefs is not None:
                        # now search the component refs and see which one
                        # matches the ref to trim
                        for cr in connectionRefs:
                            if cr.makeCompositeRef() == ref:
                                toRemove = cr
                        # Stop at the first component dataset type with refs;
                        # skipping the for/else below means no RuntimeError.
                        break
                else:
                    # Ref must be an initInput ref and we want to ignore those
                    raise RuntimeError(f"Cannot prune on non-Input dataset type {ref.datasetType.name}")
            else:
                toRemove = ref

            # Build a copy of the node's inputs with the pruned ref removed,
            # leaving the original quantum untouched until adjustment works.
            tmpRefs = set(connectionRefs).difference((toRemove,))
            tmpConnections = NamedKeyDict[DatasetType, List[DatasetRef]](node.quantum.inputs.items())
            tmpConnections[toRemove.datasetType] = list(tmpRefs)
            helper = AdjustQuantumHelper(inputs=tmpConnections, outputs=node.quantum.outputs)
            assert node.quantum.dataId is not None, (
                "assert to make the type checker happy, it should not "
                "actually be possible to not have dataId set to None "
                "at this point"
            )

            # Try to adjust the quantum with the reduced refs to make sure the
            # node will still satisfy all its conditions.
            #
            # If it can't because NoWorkFound is raised, that means a
            # connection is no longer present, and the node should be removed
            # from the graph.
            try:
                helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId)
                newQuantum = Quantum(
                    taskName=node.quantum.taskName,
                    taskClass=node.quantum.taskClass,
                    dataId=node.quantum.dataId,
                    initInputs=node.quantum.initInputs,
                    inputs=helper.inputs,
                    outputs=helper.outputs,
                )
                # If the inputs or outputs were adjusted to something different
                # than what was supplied by the graph builder, disassociate
                # node from those refs, and if they are output refs, prune them
                # from downstream tasks. This means that based on new inputs
                # the task wants to produce fewer outputs, or consume fewer
                # inputs.
                for condition, existingMapping, newMapping, remover in (
                    (
                        helper.inputs_adjusted,
                        node.quantum.inputs,
                        helper.inputs,
                        datasetRefDict.removeConsumer,
                    ),
                    (
                        helper.outputs_adjusted,
                        node.quantum.outputs,
                        helper.outputs,
                        datasetRefDict.removeProducer,
                    ),
                ):
                    if condition:
                        # Collect every ref present before adjustment but
                        # absent afterwards (composite refs are tracked in
                        # place of their components).
                        notNeeded = set()
                        for key in existingMapping:
                            if key not in newMapping:
                                compositeRefs = (
                                    r if not r.isComponent() else r.makeCompositeRef()
                                    for r in existingMapping[key]
                                )
                                notNeeded |= set(compositeRefs)
                                continue
                            notNeeded |= set(existingMapping[key]) - set(newMapping[key])
                        if notNeeded:
                            for ref in notNeeded:
                                if ref.isComponent():
                                    ref = ref.makeCompositeRef()
                                remover(ref, node)
                            # Dropped outputs can orphan downstream consumers;
                            # recurse to prune them as well.
                            if remover is datasetRefDict.removeProducer:
                                _pruner(datasetRefDict, notNeeded, alreadyPruned=alreadyPruned)
                # node.quantum is a frozen attribute; object.__setattr__ is
                # the established way to swap in the adjusted quantum.
                object.__setattr__(node, "quantum", newQuantum)
                noWorkFound = False

            except NoWorkFound:
                noWorkFound = True

            if noWorkFound:
                # The node can no longer run: detach it from everything it
                # consumes (components tracked as their composites).
                # This will throw if the length is less than the minimum number
                for tmpRef in chain(
                    chain.from_iterable(node.quantum.inputs.values()), node.quantum.initInputs.values()
                ):
                    if tmpRef.isComponent():
                        tmpRef = tmpRef.makeCompositeRef()
                    datasetRefDict.removeConsumer(tmpRef, node)
                alreadyPruned.add(node)
                # prune all outputs produced by this node
                # mark that none of these will be produced
                forwardPrunes = set()
                for forwardRef in chain.from_iterable(node.quantum.outputs.values()):
                    datasetRefDict.removeProducer(forwardRef, node)
                    forwardPrunes.add(forwardRef)
                _pruner(datasetRefDict, forwardPrunes, alreadyPruned=alreadyPruned)