Coverage for tests/test_quantumGraph.py: 30%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

283 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from itertools import chain 

29from typing import Iterable 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import QuantumNode 

44 

# boto3 and moto are optional; they are only needed for the S3 round-trip
# test, which is skipped when ``boto3`` is falsy.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    boto3 = None

    def mock_s3(cls):
        """Return the decorated object unchanged.

        Stand-in for ``moto.mock_s3`` so that the ``@mock_s3`` decoration
        below still works when moto (or boto3) can not be imported.
        """
        return cls

54 

55 

# Metadata attached to the quantum graph built in ``setUp``; the URI
# round-trip test checks that it is restored unchanged.
METADATA = {"a": [1, 2, 3]}

57 

58 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first dummy task.

    Its ``Dummy1Output`` and ``Dummy1InitOutput`` dataset types are the ones
    `Dummy2Connections` declares as input and init-input.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

63 

64 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

67 

68 

class Dummy1PipelineTask(PipelineTask):
    """First dummy task; only declares the config class it is paired with."""

    ConfigClass = Dummy1Config

71 

72 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second dummy task.

    Consumes the ``Dummy1InitOutput`` and ``Dummy1Output`` dataset types and
    produces ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

78 

79 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

82 

83 

class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; only declares the config class it is paired with."""

    ConfigClass = Dummy2Config

86 

87 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third dummy task.

    Consumes the ``Dummy2InitOutput`` and ``Dummy2Output`` dataset types and
    produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

93 

94 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

97 

98 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; only declares the config class it is paired with."""

    ConfigClass = Dummy3Config

101 

102 

103# Test if a Task that does not interact with the other Tasks works fine in 

104# the graph. 

class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the fourth dummy task.

    Its dataset type names (``Dummy4Input``, ``Dummy4Output``) are not used
    by any of the other dummy connection classes, and it declares no
    init-input or init-output.
    """

    input = cT.Input(
        name="Dummy4Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy4Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

108 

109 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for `Dummy4PipelineTask`, bound to `Dummy4Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

112 

113 

class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy task; only declares the config class it is paired with."""

    ConfigClass = Dummy4Config

116 

117 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    The fixture graph is built from four tasks labelled ``R``, ``S``, ``T``
    and ``U``.  Each task gets one quantum for data ID ``(A=1, B=2)`` and one
    for ``(A=3, B=4)``.  ``R``, ``S`` and ``T`` are linked through their
    dataset type names (each consumes the previous task's output); ``U``
    shares no dataset type with the others.
    """

    def setUp(self):
        """Build the dimension universe and the quantum graph under test."""
        universeConfig = Config(
            {
                "version": 1,
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=universeConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            taskDef = TaskDef(f"__main__.{task.__qualname__}", task.ConfigClass(), task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # The dataset types only depend on the task, so build them once
            # per task instead of once per quantum.
            if connections.initInputs:
                initInputDSType = DatasetType(
                    connections.initInput.name,
                    tuple(),
                    storageClass=connections.initInput.storageClass,
                    universe=universe,
                )
            else:
                initInputDSType = None
            inputDSType = DatasetType(
                connections.input.name,
                connections.input.dimensions,
                storageClass=connections.input.storageClass,
                universe=universe,
            )
            outputDSType = DatasetType(
                connections.output.name,
                connections.output.dimensions,
                storageClass=connections.output.storageClass,
                universe=universe,
            )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # The quantum, its input and its output all share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip the module prefix from every task name of both graphs.

        Both graphs are updated in place.
        """
        # This is a hack for the unit test since the qualified name will be
        # different as it will be __main__ here, but qualified to the
        # unittest module name when restored
        for saved, loaded in zip(graph1.taskGraph, graph2.taskGraph):
            saved.taskName = saved.taskName.split(".")[-1]
            loaded.taskName = loaded.taskName.split(".")[-1]

    def testTaskGraph(self):
        """Every task definition used to build the graph is in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum used to build the graph is held by one of its nodes."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """A node is found by its id; an unknown id raises `KeyError`."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip yields an equal graph."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """The input quanta are those of the first and the fourth task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        """The output quanta are those of the third and the fourth task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph length is two quanta per task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """Each task maps back to the quanta it was registered with."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        """The nodes of each task wrap the quanta it was registered with."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """The first task found with ``Dummy1Output`` as input is the second."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """The task producing ``Dummy1Output`` is the first one."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is associated with the first two tasks."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """The first match for the unqualified task name is the first task."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Label ``R`` resolves to the first task."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """The quanta using ``Dummy1Input`` are those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """All connection names except init-outputs are reported."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that node's quantum and the build id."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is not connected; a two-node subset is."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        filteredNodes = [n for n in self.qGraph if n.taskDef.label != "U"]
        subset = self.qGraph.subset((filteredNodes[0], filteredNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """Splitting yields four connected subgraphs covering the graph."""
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # The memberships summed over all subgraphs must add up to the
        # number of quanta in the full graph.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        # Check the task content of every subgraph.  (Collecting them in a
        # dict keyed by size would silently keep only one graph per size.)
        for cg in connectedGraphs:
            tskSet = set(cg.taskGraph)
            if len(tskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif len(tskSet) == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        """The outputs of the first task's nodes are the second task's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        """The second task's nodes connect to the first and third tasks'."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        """Ancestors of the second task's nodes: themselves plus the first's."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        """The fixture graph reports no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round trip through a file object, whole graph and a single node."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testSaveLoadUri(self):
        """Round trip through a local URI, including partial loads and errors."""
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self.assertEqual(restore.metadata, METADATA)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a nonexistent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # delete=False above, so the file has to be removed by hand
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Round trip through a mocked S3 bucket."""
        # Test loading a quantum graph from a mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        """A node obtained by iterating the graph is contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

481 

482 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Pull the checks of `lsst.utils.tests.MemoryTestCase` into this module."""

485 

486 

def setup_module(module):
    """Initialise the LSST test utilities; ``module`` is not used here."""
    lsst.utils.tests.init()

489 

490 

# Direct execution: initialise the LSST test utilities, then hand over to
# unittest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()