Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""Simple unit tests for the MPGraphExecutor class.

The tests run MPGraphExecutor over lightweight mock quantum graphs, task
definitions and quantum executors, in both single-process and multi-process
modes, including failure, fail-fast and timeout handling.
"""

24 

25import logging 

26import networkx as nx 

27from multiprocessing import Manager 

28import psutil 

29import time 

30from types import SimpleNamespace 

31import unittest 

32 

33from lsst.ctrl.mpexec import MPGraphExecutor, MPGraphExecutorError, MPTimeoutError, QuantumExecutor 

34from lsst.ctrl.mpexec.execFixupDataId import ExecFixupDataId 

35from lsst.pipe.base import NodeId 

36 

37 

# Module-level logger used by the mock classes below.
_LOG = logging.getLogger(__name__)

# Emit debug output so the execution order of quanta is visible when the
# tests run.
logging.basicConfig(level=logging.DEBUG)

41 

42 

class QuantumExecutorMock(QuantumExecutor):
    """Mock QuantumExecutor which records every quantum it executes.

    Parameters
    ----------
    mp : `bool`, optional
        If `True`, keep the record in a list created by a
        `multiprocessing.Manager` so that quanta executed in subprocesses
        are visible to the parent; otherwise use a plain `list`.
    """
    def __init__(self, mp=False):
        if mp:
            # Manager-backed list proxy, shared between processes.
            self.quanta = Manager().list()
        else:
            self.quanta = []

    def execute(self, taskDef, quantum, butler):
        """Run the task's ``runQuantum`` (when there is a task class), then
        record ``quantum``.

        ``butler`` is not used. Recording happens only after ``runQuantum``
        returns, so a quantum whose task raises is not recorded.
        """
        _LOG.debug("QuantumExecutorMock.execute: taskDef=%s dataId=%s", taskDef, quantum.dataId)
        taskClass = taskDef.taskClass
        if taskClass:
            # Needs a task class with a no-argument runQuantum(), e.g.
            # TaskMockMP, TaskMockFail or TaskMockSleep below.
            task = taskClass()
            task.runQuantum()
        self.quanta.append(quantum)

    def getDataIds(self, field):
        """Return the dataId value of ``field`` for each recorded quantum,
        in recording order."""
        return [q.dataId[field] for q in self.quanta]

63 

64 

class QuantumIterDataMock:
    """Minimal stand-in for a QuantumIterData graph node.

    Parameters
    ----------
    index : `int`
        Index of this quantum; used to build its ``nodeId`` and compared
        against other quanta's ``dependencies``.
    taskDef : `TaskDefMock`
        Task definition this quantum belongs to.
    **dataId
        Data ID key/value pairs, exposed as ``self.quantum.dataId``.
    """
    def __init__(self, index, taskDef, **dataId):
        self.index = index
        self.taskDef = taskDef
        # Tests fill this set with the indices of upstream quanta.
        self.dependencies = set()
        self.quantum = SimpleNamespace(dataId=dataId)
        self.nodeId = NodeId(index, "DummyBuildString")

74 

75 

class QuantumGraphMock:
    """Mock for quantum graph.

    Quanta are stored as nodes of a `networkx.DiGraph` with an edge from
    each quantum to the next one in ``qdata``. With only those edges,
    iteration (a topological sort) yields the quanta in the order given.

    Parameters
    ----------
    qdata : `list` [`QuantumIterDataMock`]
        Quanta to put in the graph; may be empty or hold a single quantum.
    """
    def __init__(self, qdata):
        self._graph = nx.DiGraph()
        # Add the nodes explicitly. A single quantum has no edges and would
        # otherwise never enter the graph, and an empty list must not
        # raise IndexError.
        self._graph.add_nodes_from(qdata)
        # Link consecutive quanta into a chain to fix the iteration order.
        self._graph.add_edges_from(zip(qdata, qdata[1:]))

    def __iter__(self):
        """Yield quanta in topological order of the underlying graph."""
        yield from nx.topological_sort(self._graph)

    def findTaskDefByLabel(self, label):
        """Return the task definition with the given label, or `None` if no
        quantum uses such a task."""
        for q in self:
            if q.taskDef.label == label:
                return q.taskDef
        return None

    def quantaForTask(self, taskDef):
        """Return the set of quanta whose task definition equals
        ``taskDef``."""
        return {q for q in self if q.taskDef == taskDef}

    @property
    def graph(self):
        """The underlying `networkx.DiGraph` itself (not a copy)."""
        return self._graph

    def findCycle(self):
        """Return an empty list; the mock never reports a cycle."""
        return []

    def determineInputsToQuantumNode(self, node):
        """Return the set of quanta whose ``index`` is listed in
        ``node.dependencies``."""
        deps = node.dependencies
        if not deps:
            # Nothing to look up; avoid sorting the graph at all.
            return set()
        # One pass over the graph instead of one pass per dependency.
        return {other for other in self if other.index in deps}

115 

116 

class TaskMockMP:
    """Mock task that supports multiprocessing and does nothing when run."""

    canMultiprocess = True

    def runQuantum(self):
        """Log a debug message; no other work is done."""
        _LOG.debug("TaskMockMP.runQuantum")

125 

126 

class TaskMockFail:
    """Mock task that supports multiprocessing but always fails when run."""

    canMultiprocess = True

    def runQuantum(self):
        """Log a debug message and then fail.

        Raises
        ------
        ValueError
            Always.
        """
        _LOG.debug("TaskMockFail.runQuantum")
        message = "expected failure"
        raise ValueError(message)

135 

136 

137class TaskMockSleep: 

138 """Simple mock class for task which fails. 

139 """ 

140 canMultiprocess = True 

141 

142 def runQuantum(self): 

143 _LOG.debug("TaskMockSleep.runQuantum") 

144 time.sleep(3.) 

145 

146 

class TaskMockNoMP:
    """Mock task that declares it cannot run with multiprocessing.

    It defines no ``runQuantum``. In this file it is only used to check
    that a multi-process MPGraphExecutor refuses to run it.
    """

    canMultiprocess = False

151 

152 

class TaskDefMock:
    """Simple mock class for task definition in a pipeline.

    Parameters
    ----------
    taskName : `str`, optional
        Name of the task.
    config : optional
        Task configuration, stored as given (`None` by default).
    taskClass : `type` or `None`, optional
        Task class. QuantumExecutorMock instantiates it and calls its
        ``runQuantum`` only when this is truthy.
    label : `str`, optional
        Task label.
    """
    def __init__(self, taskName="Task", config=None, taskClass=TaskMockMP, label="task1"):
        self.taskName = taskName
        self.config = config
        self.taskClass = taskClass
        self.label = label

    def __str__(self):
        # taskClass may be None (QuantumExecutorMock.execute tolerates a
        # falsy taskClass), so do not assume it has a __name__.
        className = getattr(self.taskClass, "__name__", self.taskClass)
        return f"TaskDefMock(taskName={self.taskName}, taskClass={className})"

164 

165 

class MPGraphExecutorTestCase(unittest.TestCase):
    """A test case for MPGraphExecutor class.

    Each test builds a QuantumGraphMock whose quanta carry a ``detector``
    dataId value equal to their index, executes it with MPGraphExecutor,
    and checks which detectors QuantumExecutorMock recorded as executed.
    """

    @staticmethod
    def _makeGraph(taskDef, nQuanta):
        """Return a QuantumGraphMock of ``nQuanta`` quanta of ``taskDef``
        with ``index`` and ``detector`` both running from 0 to nQuanta-1."""
        return QuantumGraphMock([
            QuantumIterDataMock(index=i, taskDef=taskDef, detector=i) for i in range(nQuanta)
        ])

    def test_mpexec_nomp(self):
        """Make simple graph and execute in single-process mode; quanta run
        in graph order."""
        qgraph = self._makeGraph(TaskDefMock(), 3)

        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=1, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph, butler=None)
        self.assertEqual(qexec.getDataIds("detector"), [0, 1, 2])

    def test_mpexec_mp(self):
        """Make simple graph and execute in multi-process mode; every
        quantum runs, in no particular order."""
        qgraph = self._makeGraph(TaskDefMock(), 3)

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph, butler=None)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 1, 2])

    def test_mpexec_nompsupport(self):
        """Try to run MP for task that has no MP support which should fail.
        """
        qgraph = self._makeGraph(TaskDefMock(taskClass=TaskMockNoMP), 3)

        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaises(MPGraphExecutorError):
            mpexec.execute(qgraph, butler=None)

    def test_mpexec_fixup(self):
        """Make simple graph and execute, add dependencies by executing
        fixup code; the fixup's ``reverse`` flag controls execution order.
        """
        taskDef = TaskDefMock()

        for reverse in (False, True):
            with self.subTest(reverse=reverse):
                qgraph = self._makeGraph(taskDef, 3)

                qexec = QuantumExecutorMock()
                # "task1" is the default TaskDefMock label.
                fixup = ExecFixupDataId("task1", "detector", reverse=reverse)
                mpexec = MPGraphExecutor(numProc=1, timeout=100, quantumExecutor=qexec,
                                         executionGraphFixup=fixup)
                mpexec.execute(qgraph, butler=None)

                expected = [2, 1, 0] if reverse else [0, 1, 2]
                self.assertEqual(qexec.getDataIds("detector"), expected)

    def test_mpexec_timeout(self):
        """Fail due to timeout of the sleeping quantum."""
        taskDef = TaskDefMock()
        taskDefSleep = TaskDefMock(taskClass=TaskMockSleep)
        qgraph = QuantumGraphMock([
            QuantumIterDataMock(index=0, taskDef=taskDef, detector=0),
            QuantumIterDataMock(index=1, taskDef=taskDefSleep, detector=1),
            QuantumIterDataMock(index=2, taskDef=taskDef, detector=2),
        ])

        # with failFast we'll get immediate MPTimeoutError
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=1, quantumExecutor=qexec, failFast=True)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qgraph, butler=None)

        # with failFast=False exception happens after last task finishes
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=1, quantumExecutor=qexec, failFast=False)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qgraph, butler=None)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 2])

    def test_mpexec_failure(self):
        """Failure in one task should not stop other tasks."""
        taskDef = TaskDefMock()
        taskDefFail = TaskDefMock(taskClass=TaskMockFail)
        qgraph = QuantumGraphMock([
            QuantumIterDataMock(index=0, taskDef=taskDef, detector=0),
            QuantumIterDataMock(index=1, taskDef=taskDefFail, detector=1),
            QuantumIterDataMock(index=2, taskDef=taskDef, detector=2),
        ])

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaises(MPGraphExecutorError):
            mpexec.execute(qgraph, butler=None)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 2])

    def test_mpexec_failure_dep(self):
        """Failure in one task should skip dependents."""
        taskDef = TaskDefMock()
        taskDefFail = TaskDefMock(taskClass=TaskMockFail)
        qdata = [
            QuantumIterDataMock(index=0, taskDef=taskDef, detector=0),
            QuantumIterDataMock(index=1, taskDef=taskDefFail, detector=1),
            QuantumIterDataMock(index=2, taskDef=taskDef, detector=2),
            QuantumIterDataMock(index=3, taskDef=taskDef, detector=3),
            QuantumIterDataMock(index=4, taskDef=taskDef, detector=4),
        ]
        # 2 depends on the failing 1, and 4 depends on 2, so both are
        # skipped; only 0 and 3 are expected to run.
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaises(MPGraphExecutorError):
            mpexec.execute(qgraph, butler=None)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 3])

    def test_mpexec_failure_failfast(self):
        """Fast fail stops quickly.

        Timing delay of task #3 should be sufficient to process
        failure and raise exception.
        """
        taskDef = TaskDefMock()
        taskDefFail = TaskDefMock(taskClass=TaskMockFail)
        taskDefSleep = TaskDefMock(taskClass=TaskMockSleep)
        qdata = [
            QuantumIterDataMock(index=0, taskDef=taskDef, detector=0),
            QuantumIterDataMock(index=1, taskDef=taskDefFail, detector=1),
            QuantumIterDataMock(index=2, taskDef=taskDef, detector=2),
            QuantumIterDataMock(index=3, taskDef=taskDefSleep, detector=3),
            QuantumIterDataMock(index=4, taskDef=taskDef, detector=4),
        ]
        qdata[1].dependencies.add(0)
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, failFast=True)
        with self.assertRaises(MPGraphExecutorError):
            mpexec.execute(qgraph, butler=None)
        self.assertCountEqual(qexec.getDataIds("detector"), [0])

    @unittest.skipUnless(hasattr(psutil.Process, "num_fds"),
                         "psutil.Process.num_fds() is not available on this platform")
    def test_mpexec_num_fd(self):
        """Check that number of open files stays reasonable.
        """
        qgraph = self._makeGraph(TaskDefMock(), 20)

        # Create the mock executor before taking the baseline. Its
        # multiprocessing Manager may hold descriptors of its own, and those
        # must not be counted against MPGraphExecutor.
        qexec = QuantumExecutorMock(mp=True)

        this_proc = psutil.Process()
        num_fds_0 = this_proc.num_fds()

        # run in multi-process mode, the order of results is not defined
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph, butler=None)

        num_fds_1 = this_proc.num_fds()
        # They should be the same but allow small growth just in case.
        # Without DM-26728 fix the difference would be equal to number of
        # quanta (20).
        self.assertLess(num_fds_1 - num_fds_0, 5)

352 

353 

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()