Coverage for python/lsst/pipe/base/graph/_implDetails.py: 15%

132 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-12 11:14 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("_DatasetTracker", "DatasetTypeName", "_pruner") 

24 

25from collections import defaultdict 

26from collections.abc import Iterable 

27from itertools import chain 

28from typing import Generic, NewType, TypeVar 

29 

30import networkx as nx 

31from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum 

32from lsst.pipe.base.connections import AdjustQuantumHelper 

33 

34from .._status import NoWorkFound 

35from ..pipeline import TaskDef 

36from .quantumNode import QuantumNode 

37 

# NewTypes
# Distinct subtype of `str` naming a dataset type, so dataset-type names
# are not confused with arbitrary strings in signatures.
DatasetTypeName = NewType("DatasetTypeName", str)

# Generic type parameters
# _T: the "key" type tracked by _DatasetTracker — either a dataset type
#     name or a resolved dataset ref.
# _U: the "value" type that produces/consumes keys — either a task
#     definition or a quantum-graph node.
_T = TypeVar("_T", DatasetTypeName, DatasetRef)
_U = TypeVar("_U", TaskDef, QuantumNode)

45 

46class _DatasetTracker(Generic[_T, _U]): 

47 r"""A generic container for tracking keys which are produced or 

48 consumed by some value. In the context of a QuantumGraph, keys may be 

49 `~lsst.daf.butler.DatasetRef`\ s and the values would be Quanta that either 

50 produce or consume those `~lsst.daf.butler.DatasetRef`\ s. 

51 

52 Prameters 

53 --------- 

54 createInverse : bool 

55 When adding a key associated with a producer or consumer, also create 

56 and inverse mapping that allows looking up all the keys associated with 

57 some value. Defaults to False. 

58 """ 

59 

60 def __init__(self, createInverse: bool = False): 

61 self._producers: dict[_T, _U] = {} 

62 self._consumers: defaultdict[_T, set[_U]] = defaultdict(set) 

63 self._createInverse = createInverse 

64 if self._createInverse: 

65 self._itemsDict: defaultdict[_U, set[_T]] = defaultdict(set) 

66 

67 def addProducer(self, key: _T, value: _U) -> None: 

68 """Add a key which is produced by some value. 

69 

70 Parameters 

71 ---------- 

72 key : `~typing.TypeVar` 

73 The type to track. 

74 value : `~typing.TypeVar` 

75 The type associated with the production of the key. 

76 

77 Raises 

78 ------ 

79 ValueError 

80 Raised if key is already declared to be produced by another value. 

81 """ 

82 if (existing := self._producers.get(key)) is not None and existing != value: 

83 raise ValueError(f"Only one node is allowed to produce {key}, the current producer is {existing}") 

84 self._producers[key] = value 

85 if self._createInverse: 

86 self._itemsDict[value].add(key) 

87 

88 def removeProducer(self, key: _T, value: _U) -> None: 

89 """Remove a value (e.g. `QuantumNode` or `TaskDef`) from being 

90 considered a producer of the corresponding key. 

91 

92 It is not an error to remove a key that is not in the tracker. 

93 

94 Parameters 

95 ---------- 

96 key : `~typing.TypeVar` 

97 The type to track. 

98 value : `~typing.TypeVar` 

99 The type associated with the production of the key. 

100 """ 

101 self._producers.pop(key, None) 

102 if self._createInverse: 

103 if result := self._itemsDict.get(value): 

104 result.discard(key) 

105 

106 def addConsumer(self, key: _T, value: _U) -> None: 

107 """Add a key which is consumed by some value. 

108 

109 Parameters 

110 ---------- 

111 key : `~typing.TypeVar` 

112 The type to track. 

113 value : `~typing.TypeVar` 

114 The type associated with the consumption of the key. 

115 """ 

116 self._consumers[key].add(value) 

117 if self._createInverse: 

118 self._itemsDict[value].add(key) 

119 

120 def removeConsumer(self, key: _T, value: _U) -> None: 

121 """Remove a value (e.g. `QuantumNode` or `TaskDef`) from being 

122 considered a consumer of the corresponding key. 

123 

124 It is not an error to remove a key that is not in the tracker. 

125 

126 Parameters 

127 ---------- 

128 key : `~typing.TypeVar` 

129 The type to track. 

130 value : `~typing.TypeVar` 

131 The type associated with the consumption of the key. 

132 """ 

133 if (result := self._consumers.get(key)) is not None: 

134 result.discard(value) 

135 if self._createInverse: 

136 if result_inverse := self._itemsDict.get(value): 

137 result_inverse.discard(key) 

138 

139 def getConsumers(self, key: _T) -> set[_U]: 

140 """Return all values associated with the consumption of the supplied 

141 key. 

142 

143 Parameters 

144 ---------- 

145 key : `~typing.TypeVar` 

146 The type which has been tracked in the `_DatasetTracker`. 

147 """ 

148 return self._consumers.get(key, set()) 

149 

150 def getProducer(self, key: _T) -> _U | None: 

151 """Return the value associated with the consumption of the supplied 

152 key. 

153 

154 Parameters 

155 ---------- 

156 key : `~typing.TypeVar` 

157 The type which has been tracked in the `_DatasetTracker`. 

158 """ 

159 # This tracker may have had all nodes associated with a key removed 

160 # and if there are no refs (empty set) should return None 

161 return producer if (producer := self._producers.get(key)) else None 

162 

163 def getAll(self, key: _T) -> set[_U]: 

164 """Return all consumers and the producer associated with the the 

165 supplied key. 

166 

167 Parameters 

168 ---------- 

169 key : `~typing.TypeVar` 

170 The type which has been tracked in the `_DatasetTracker`. 

171 """ 

172 return self.getConsumers(key).union(x for x in (self.getProducer(key),) if x is not None) 

173 

174 @property 

175 def inverse(self) -> defaultdict[_U, set[_T]] | None: 

176 """Return the inverse mapping if class was instantiated to create an 

177 inverse, else return None. 

178 """ 

179 return self._itemsDict if self._createInverse else None 

180 

181 def makeNetworkXGraph(self) -> nx.DiGraph: 

182 """Create a NetworkX graph out of all the contained keys, using the 

183 relations of producer and consumers to create the edges. 

184 

185 Returns 

186 ------- 

187 graph : `networkx.DiGraph` 

188 The graph created out of the supplied keys and their relations. 

189 """ 

190 graph = nx.DiGraph() 

191 for entry in self._producers.keys() | self._consumers.keys(): 

192 producer = self.getProducer(entry) 

193 consumers = self.getConsumers(entry) 

194 # This block is for tasks that consume existing inputs 

195 if producer is None and consumers: 

196 for consumer in consumers: 

197 graph.add_node(consumer) 

198 # This block is for tasks that produce output that is not consumed 

199 # in this graph 

200 elif producer is not None and not consumers: 

201 graph.add_node(producer) 

202 # all other connections 

203 else: 

204 for consumer in consumers: 

205 graph.add_edge(producer, consumer) 

206 return graph 

207 

208 def keys(self) -> set[_T]: 

209 """Return all tracked keys.""" 

210 return self._producers.keys() | self._consumers.keys() 

211 

212 def remove(self, key: _T) -> None: 

213 """Remove a key and its corresponding value from the tracker, this is 

214 a no-op if the key is not in the tracker. 

215 

216 Parameters 

217 ---------- 

218 key : `~typing.TypeVar` 

219 A key tracked by the `_DatasetTracker`. 

220 """ 

221 self._producers.pop(key, None) 

222 self._consumers.pop(key, None) 

223 

224 def __contains__(self, key: _T) -> bool: 

225 """Check if a key is in the `_DatasetTracker`. 

226 

227 Parameters 

228 ---------- 

229 key : `~typing.TypeVar` 

230 The key to check. 

231 

232 Returns 

233 ------- 

234 contains : `bool` 

235 Boolean of the presence of the supplied key. 

236 """ 

237 return key in self._producers or key in self._consumers 

238 

239 

def _pruner(
    datasetRefDict: _DatasetTracker[DatasetRef, QuantumNode],
    refsToRemove: Iterable[DatasetRef],
    *,
    alreadyPruned: set[QuantumNode] | None = None,
) -> None:
    r"""Prune supplied dataset refs out of ``datasetRefDict`` container,
    recursing to additional nodes dependent on pruned refs.

    Parameters
    ----------
    datasetRefDict : `_DatasetTracker` [ `~lsst.daf.butler.DatasetRef`, \
            `QuantumNode`]
        The dataset tracker that maps `~lsst.daf.butler.DatasetRef`\ s to the
        `QuantumNode`\ s that produce/consume that
        `~lsst.daf.butler.DatasetRef`.
        This function modifies ``datasetRefDict`` in-place.
    refsToRemove : `~collections.abc.Iterable` of `~lsst.daf.butler.DatasetRef`
        The `~lsst.daf.butler.DatasetRef`\ s which should be pruned from the
        input dataset tracker.
    alreadyPruned : `set` of `QuantumNode`, optional
        A set of nodes which have been pruned from the dataset tracker;
        mutated in place and passed through the recursive calls so each node
        is only pruned once.

    Raises
    ------
    RuntimeError
        Raised when a ref to prune is neither a regular input nor a component
        of a regular input of a consuming node (i.e. it must be an initInput,
        which cannot be pruned here).
    """
    if alreadyPruned is None:
        alreadyPruned = set()
    for ref in refsToRemove:
        # make a copy here, because this structure will be modified in
        # recursion, hitting a node more than once won't be much of an
        # issue, as we skip anything that has been processed
        nodes = set(datasetRefDict.getConsumers(ref))
        for node in nodes:
            # This node will never be associated with this ref
            datasetRefDict.removeConsumer(ref, node)
            if node in alreadyPruned:
                continue
            # find the connection corresponding to the input ref
            connectionRefs = node.quantum.inputs.get(ref.datasetType)
            if connectionRefs is None:
                # look to see if any inputs are component refs that match the
                # input ref to prune
                others = ref.datasetType.makeAllComponentDatasetTypes()
                # for each other component type check if there are associated
                # refs
                for other in others:
                    connectionRefs = node.quantum.inputs.get(other)
                    if connectionRefs is not None:
                        # now search the component refs and see which one
                        # matches the ref to trim
                        # NOTE(review): if no component ref here matches,
                        # ``toRemove`` is left unbound and its use below would
                        # raise NameError; presumably a match always exists
                        # once the component type is found — confirm.
                        for cr in connectionRefs:
                            if cr.makeCompositeRef() == ref:
                                toRemove = cr
                        # Stop at the first component type that has refs; the
                        # for/else below raises only when none was found.
                        break
                else:
                    # Ref must be an initInput ref and we want to ignore those
                    raise RuntimeError(f"Cannot prune on non-Input dataset type {ref.datasetType.name}")
            else:
                toRemove = ref

            # Build a modified copy of the node's input mapping with the
            # pruned ref removed from its dataset type's list.
            tmpRefs = set(connectionRefs).difference((toRemove,))
            tmpConnections = NamedKeyDict[DatasetType, list[DatasetRef]](node.quantum.inputs.items())
            tmpConnections[toRemove.datasetType] = list(tmpRefs)
            helper = AdjustQuantumHelper(inputs=tmpConnections, outputs=node.quantum.outputs)
            assert node.quantum.dataId is not None, (
                "assert to make the type checker happy, it should not "
                "actually be possible to not have dataId set to None "
                "at this point"
            )

            # Try to adjust the quantum with the reduced refs to make sure the
            # node will still satisfy all its conditions.
            #
            # If it can't because NoWorkFound is raised, that means a
            # connection is no longer present, and the node should be removed
            # from the graph.
            try:
                helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId)
                # ignore the types because quantum really can take a sequence
                # of inputs
                newQuantum = Quantum(
                    taskName=node.quantum.taskName,
                    taskClass=node.quantum.taskClass,
                    dataId=node.quantum.dataId,
                    initInputs=node.quantum.initInputs,
                    inputs=helper.inputs,
                    outputs=helper.outputs,
                )
                # If the inputs or outputs were adjusted to something different
                # than what was supplied by the graph builder, disassociate the
                # node from those refs, and if they are output refs, prune them
                # from downstream tasks. This means that based on new inputs
                # the task wants to produce fewer outputs, or consume fewer
                # inputs.
                for condition, existingMapping, newMapping, remover in (
                    (
                        helper.inputs_adjusted,
                        node.quantum.inputs,
                        helper.inputs,
                        datasetRefDict.removeConsumer,
                    ),
                    (
                        helper.outputs_adjusted,
                        node.quantum.outputs,
                        helper.outputs,
                        datasetRefDict.removeProducer,
                    ),
                ):
                    if condition:
                        notNeeded = set()
                        for key in existingMapping:
                            if key not in newMapping:
                                # Whole dataset type dropped: collect all its
                                # refs, promoting component refs to their
                                # composite (the tracker is keyed by
                                # composites).
                                compositeRefs = (
                                    r if not r.isComponent() else r.makeCompositeRef()
                                    for r in existingMapping[key]
                                )
                                notNeeded |= set(compositeRefs)
                                continue
                            notNeeded |= set(existingMapping[key]) - set(newMapping[key])
                        if notNeeded:
                            # NOTE(review): this loop rebinds ``ref``, the
                            # variable of the outermost ``refsToRemove`` loop;
                            # subsequent iterations of the ``node`` loop will
                            # see the clobbered value — verify intended.
                            for ref in notNeeded:
                                if ref.isComponent():
                                    ref = ref.makeCompositeRef()
                                remover(ref, node)
                            if remover is datasetRefDict.removeProducer:
                                # Outputs that will no longer be produced must
                                # be pruned from everything downstream.
                                _pruner(datasetRefDict, notNeeded, alreadyPruned=alreadyPruned)
                object.__setattr__(node, "quantum", newQuantum)
                noWorkFound = False

            except NoWorkFound:
                noWorkFound = True

            if noWorkFound:
                # The node does no work: detach it from every remaining input
                # (and initInput) it still consumes.
                # This will throw if the length is less than the minimum number
                for tmpRef in chain(
                    chain.from_iterable(node.quantum.inputs.values()), node.quantum.initInputs.values()
                ):
                    if tmpRef.isComponent():
                        tmpRef = tmpRef.makeCompositeRef()
                    datasetRefDict.removeConsumer(tmpRef, node)
                alreadyPruned.add(node)
                # prune all outputs produced by this node
                # mark that none of these will be produced
                forwardPrunes = set()
                for forwardRef in chain.from_iterable(node.quantum.outputs.values()):
                    datasetRefDict.removeProducer(forwardRef, node)
                    forwardPrunes.add(forwardRef)
                _pruner(datasetRefDict, forwardPrunes, alreadyPruned=alreadyPruned)