Coverage for tests/test_quantumGraph.py: 30%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

284 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from itertools import chain 

23import os 

24import pickle 

25import tempfile 

26from typing import Iterable 

27import unittest 

28import uuid 

29import random 

30from lsst.daf.butler import DimensionUniverse 

31 

32from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

33 DatasetTypeName) 

34import lsst.pipe.base.connectionTypes as cT 

35from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

36from lsst.pex.config import Field 

37from lsst.pipe.base.graph.quantumNode import QuantumNode 

38import lsst.utils.tests 

39 

# boto3 and moto are optional test dependencies.  When either is missing,
# ``boto3`` is set to None (so the S3 round-trip test is skipped via
# ``unittest.skipIf``) and ``mock_s3`` becomes a pass-through decorator.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    boto3 = None

    def mock_s3(cls):
        """Stand-in for ``moto.mock_s3``: hand back ``cls`` untouched.
        """
        decorated = cls
        return decorated

50 

# Metadata attached to the graph built in setUp; testSaveLoadUri checks that
# it survives a save/load round trip.
METADATA = dict(a=[1, 2, 3])

52 

53 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first task in the R -> S -> T chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` (both over dimensions
    A and B), plus the ``Dummy1InitOutput`` init-output.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
    output = cT.Output(
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )

66 

67 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config for `Dummy1PipelineTask`, holding one placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

70 

71 

class Dummy1PipelineTask(PipelineTask):
    """Placeholder task; it defines nothing beyond its ``ConfigClass``."""

    ConfigClass = Dummy1Config

74 

75 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the middle task in the R -> S -> T chain.

    Consumes ``Dummy1InitOutput`` and ``Dummy1Output``; produces
    ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
    output = cT.Output(
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )

91 

92 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config for `Dummy2PipelineTask`, holding one placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

95 

96 

class Dummy2PipelineTask(PipelineTask):
    """Placeholder task; it defines nothing beyond its ``ConfigClass``."""

    ConfigClass = Dummy2Config

99 

100 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the last task in the R -> S -> T chain.

    Consumes ``Dummy2InitOutput`` and ``Dummy2Output``; produces
    ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
    output = cT.Output(
        name="Dummy3Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )

116 

117 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Config for `Dummy3PipelineTask`, holding one placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

120 

121 

class Dummy3PipelineTask(PipelineTask):
    """Placeholder task; it defines nothing beyond its ``ConfigClass``."""

    ConfigClass = Dummy3Config

124 

125 

# A task that shares no datasets with the other tasks, used to check that
# such an isolated task is handled correctly in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the stand-alone task.

    Reads ``Dummy4Input`` and writes ``Dummy4Output``; it has no
    init-inputs or init-outputs.
    """

    input = cT.Input(
        name="Dummy4Input",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
    output = cT.Output(
        name="Dummy4Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )

137 

138 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Config for `Dummy4PipelineTask`, holding one placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

141 

142 

class Dummy4PipelineTask(PipelineTask):
    """Placeholder task; it defines nothing beyond its ``ConfigClass``."""

    ConfigClass = Dummy4Config

145 

146 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a `QuantumGraph` from four dummy tasks (labels R, S, T
    and U) with two quanta each, one for data ID ``{A: 1, B: 2}`` and one
    for ``{A: 3, B: 4}``.  R, S and T are linked through their inputs and
    outputs; U shares no datasets with the other tasks.
    """

    def setUp(self):
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # Mapping of TaskDef to the set of quanta generated for that task.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T"),
                            (Dummy4PipelineTask, "U")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            # Dataset types depend only on the task, not on the data ID, so
            # they are built once per task rather than once per quantum.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # The quantum and its input/output refs all share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualifiers from the task names of both graphs.

        This is a hack for the unit test: the task names start with
        ``__main__`` here, but come back qualified with the unittest module
        name after a restore.  The two task graphs are walked pairwise and
        updated in place.
        """
        for saved, loaded in zip(graph1.taskGraph,
                                 graph2.taskGraph):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        self.assertEqual(len(self.qGraph), 2*len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because the graph falls apart into several pieces: one
        # R -> S -> T chain per data ID, plus the unconnected U quanta.
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        filteredNodes = [n for n in self.qGraph if n.taskDef.label != 'U']
        subset = self.qGraph.subset((filteredNodes[0], filteredNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)
        self.assertTrue(connectedGraphs[2].isConnected)
        self.assertTrue(connectedGraphs[3].isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[2].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[3].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        # These checks rely on ``list(cg)`` yielding the chain's nodes so
        # that node 0 directly feeds node 1.
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)
            # Converse direction: in a linear chain node 1 is the only
            # consumer of node 0.
            node = set(cg.determineOutputsOfQuantumNode(allNodes[0]))
            self.assertEqual({allNodes[1]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe,
                                           nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(nodeId))

    def testSaveLoadUri(self):
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self.assertEqual(restore.metadata, METADATA)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random for the partial loads.
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                                  graphID=self.qGraph._buildId)
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0],
                                 restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(set(restoreSub),
                                 {restore.getQuantumNodeByNodeId(nodeNumber),
                                  restore.getQuantumNodeByNodeId(nodeNumber2)})
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            if uri is not None:
                os.remove(uri)

        # An unsupported extension must be rejected.  Use a scratch directory
        # so nothing is left in the working directory if a file does get
        # written.
        with tempfile.TemporaryDirectory() as tmpDir:
            with self.assertRaises(TypeError):
                self.qGraph.saveUri(os.path.join(tmpDir, "test.notgraph"))

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
489 firstNode = next(iter(self.qGraph)) 

490 self.assertIn(firstNode, self.qGraph) 

491 

492 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No additional tests are defined here.
    """

495 

496 

def setup_module(module):
    """Module-level setup hook; calls ``lsst.utils.tests.init()``.

    ``module`` is accepted to match the ``setup_module`` hook signature but
    is not used.
    """
    lsst.utils.tests.init()

499 

500 

if __name__ == "__main__":
    # Same initialization setup_module performs, then hand off to the
    # standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()