Coverage for tests/test_quantumGraph.py: 31%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

272 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from itertools import chain 

29from typing import Iterable 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import QuantumNode 

44from lsst.utils.introspection import get_full_type_name 

45 

# boto3 and moto are optional: they are only needed by the S3 round-trip test.
# When either import fails, ``boto3`` is set to None (the S3 test checks this in
# its ``skipIf`` condition) and ``mock_s3`` is replaced by a pass-through so the
# decorator can still be applied at class-definition time.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    boto3 = None

    def mock_s3(cls):
        """A no-op decorator in case moto mock_s3 can not be imported.

        Returns the decorated object unchanged.
        """
        return cls

55 

56 

# Metadata handed to the QuantumGraph built in ``setUp``; ``testSaveLoadUri``
# compares the metadata of the reloaded graph against this value.
METADATA = {"a": [1, 2, 3]}

58 

59 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first dummy task.

    Reads "Dummy1Input" and writes "Dummy1Output" and the init-output
    "Dummy1InitOutput"; the two outputs are the names `Dummy2Connections`
    declares as its inputs.
    """

    initOutput = cT.InitOutput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Input", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))

64 

65 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    # Single integer field (default 1); not read by any test in this module.
    conf1 = Field(dtype=int, default=1, doc="dummy config")

68 

69 

class Dummy1PipelineTask(PipelineTask):
    """First dummy task; used in ``setUp`` with the label "R"."""

    ConfigClass = Dummy1Config

72 

73 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second dummy task.

    Its init-input ("Dummy1InitOutput") and input ("Dummy1Output") carry the
    same names as the outputs declared by `Dummy1Connections`; its own
    outputs are "Dummy2InitOutput" and "Dummy2Output".
    """

    initInput = cT.InitInput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))

79 

80 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    # Single integer field (default 1); not read by any test in this module.
    conf1 = Field(dtype=int, default=1, doc="dummy config")

83 

84 

class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; used in ``setUp`` with the label "S"."""

    ConfigClass = Dummy2Config

87 

88 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third dummy task.

    Its init-input ("Dummy2InitOutput") and input ("Dummy2Output") carry the
    same names as the outputs declared by `Dummy2Connections`; its own
    outputs are "Dummy3InitOutput" and "Dummy3Output".
    """

    initInput = cT.InitInput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy3InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy3Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))

94 

95 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Config for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    # Single integer field (default 1); not read by any test in this module.
    conf1 = Field(dtype=int, default=1, doc="dummy config")

98 

99 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; used in ``setUp`` with the label "T"."""

    ConfigClass = Dummy3Config

102 

103 

104# Test if a Task that does not interact with the other Tasks works fine in 

105# the graph. 

class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the fourth dummy task.

    "Dummy4Input" and "Dummy4Output" are not named by any of the other dummy
    connection classes, and no init-input or init-output is declared.
    """

    input = cT.Input(name="Dummy4Input", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy4Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))

109 

110 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Config for `Dummy4PipelineTask`, bound to `Dummy4Connections`."""

    # Single integer field (default 1); not read by any test in this module.
    conf1 = Field(dtype=int, default=1, doc="dummy config")

113 

114 

class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy task; used in ``setUp`` with the label "U"."""

    ConfigClass = Dummy4Config

117 

118 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds one graph from the four dummy tasks (labels "R", "S",
    "T", "U"), each contributing one quantum for the data IDs (A=1, B=2) and
    (A=3, B=4), so the graph holds ``2 * len(self.tasks)`` nodes.
    """

    def setUp(self):
        """Build the universe, task definitions, quanta and graph under test.

        The results are stored on ``self.tasks``, ``self.quantumMap``,
        ``self.qGraph`` and ``self.universe``.
        """
        universeConfig = Config(
            {
                "version": 1,
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=universeConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # The dataset types depend only on the task's connections, so
            # build them once per task rather than once per data ID.
            initInputDSType = None
            if connections.initInputs:
                initInputDSType = DatasetType(
                    connections.initInput.name,
                    tuple(),
                    storageClass=connections.initInput.storageClass,
                    universe=universe,
                )
            inputDSType = DatasetType(
                connections.input.name,
                connections.input.dimensions,
                storageClass=connections.input.storageClass,
                universe=universe,
            )
            outputDSType = DatasetType(
                connections.output.name,
                connections.output.dimensions,
                storageClass=connections.output.storageClass,
                universe=universe,
            )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # One data ID is shared by the quantum and by its input and
                # output references.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def testTaskGraph(self):
        """Every TaskDef used to build the graph appears in ``taskGraph``."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum used to build the graph appears in ``graph``."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Lookup by a known node id succeeds; a fresh UUID raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip yields a graph equal to the original."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """``inputQuanta`` holds the quanta of the first and fourth tasks."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    # NOTE: the method name's spelling is kept so the test ID does not change.
    def testOutputtQuanta(self):
        """``outputQuanta`` holds the quanta of the third and fourth tasks."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph has two nodes (one per data ID) for each task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """``getQuantaForTask`` returns the quanta supplied for that task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        """Nodes returned for a task wrap exactly that task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """The first task found with "Dummy1Output" as input is the second."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """The task producing "Dummy1Output" is the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """"Dummy1Output" is associated with the first two tasks."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """Lookup by task qualname returns the first task as its first hit."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Lookup by the label "R" returns the first task."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    # NOTE: the method name's spelling is kept so the test ID does not change.
    def testFindQuantaWIthDSType(self):
        """Quanta found for "Dummy1Input" are those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` matches every non-InitOutput connection name."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that node's quantum and the build id."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        """Splitting into connected components yields four connected graphs."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # The number of (node, subgraph) memberships must equal the number of
        # nodes in the full graph.
        count = sum(
            1
            for node in self.qGraph
            for cg in connectedGraphs
            if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        """Outputs of the first task's nodes are the second task's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        """Connections of the second task's nodes span tasks one to three."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    # NOTE: the method name's spelling is kept so the test ID does not change.
    def testDetermineAnsestorsOfQuantumNode(self):
        """Ancestors of the second task's nodes are tasks one and two."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        """The fixture graph reports no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Save to a file object and reload the whole graph and one node."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testSaveLoadUri(self):
        """Save to a file URI and reload in full, by node, and with bad args."""
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random: the first is loaded on
                # its own, then both are loaded together.
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non-existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # delete=False above means the file must be removed by hand.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Save to a mocked S3 bucket and reload the graph and one node."""
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        """A node obtained by iterating the graph is ``in`` the graph."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

463 

464 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    Nothing is added or overridden here; the subclass exists so the inherited
    tests are collected as part of this module.
    """

467 

468 

def setup_module(module):
    """Module-level setup hook; calls ``lsst.utils.tests.init()``.

    Parameters
    ----------
    module
        Passed in by the test runner; not used here.
    """
    lsst.utils.tests.init()

471 

472 

# Direct execution: perform the same init call as ``setup_module`` and then
# hand over to the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()