Coverage for tests/test_quantumGraph.py: 27%

306 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-09-11 01:21 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from itertools import chain 

29from typing import Iterable 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import QuantumNode 

44from lsst.utils.introspection import get_full_type_name 

45 

# boto3 and moto are optional; they are only needed by the S3 round-trip test.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Falsy sentinel checked by the skipIf guard on the S3 test.
    boto3 = None

    def mock_s3(cls):
        """Return the decorated object unchanged.

        Stand-in for ``moto.mock_s3`` so the decorator syntax stays valid
        when moto cannot be imported.
        """
        return cls

55 

56 

# Metadata attached to the test graph; checked again after a save/load cycle.
METADATA = dict(a=[1, 2, 3])

58 

59 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Reads ``Dummy1Input``, writes ``Dummy1Output`` and declares the
    init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

64 

65 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration of the first dummy task; holds one integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

68 

69 

class Dummy1PipelineTask(PipelineTask):
    """First dummy task; exists only to be placed in a quantum graph."""

    ConfigClass = Dummy1Config

72 

73 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Consumes the first task's products (``Dummy1InitOutput`` and
    ``Dummy1Output``) and produces ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

79 

80 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration of the second dummy task; holds one integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

83 

84 

class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; exists only to be placed in a quantum graph."""

    ConfigClass = Dummy2Config

87 

88 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Consumes the second task's products (``Dummy2InitOutput`` and
    ``Dummy2Output``) and produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

94 

95 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration of the third dummy task; holds one integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

98 

99 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; exists only to be placed in a quantum graph."""

    ConfigClass = Dummy3Config

102 

103 

# Test if a Task that does not interact with the other Tasks works fine in
# the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the fourth dummy task.

    Its dataset type names (``Dummy4Input``, ``Dummy4Output``) are not used
    by any of the other dummy connection classes in this module.
    """

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

109 

110 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration of the fourth dummy task; holds one integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

113 

114 

class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy task; exists only to be placed in a quantum graph."""

    ConfigClass = Dummy4Config

117 

118 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph"""

    def setUp(self):
        """Build the graph fixture shared by every test.

        Four task definitions (labels R, S, T, U) each get two quanta, one
        for data ID (A=1, B=2) and one for (A=3, B=4).  The resulting
        `QuantumGraph`, the task list, the task-to-quanta mapping and the
        dimension universe are stored on ``self``.
        """
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Init connections carry no dimensions attribute, hence the
            # empty-tuple fallback.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = initRefs
            else:
                initRefs = None
            # NOTE(review): ``initRefs`` is rebound here, so the Quantum
            # objects below receive the init-*output* refs as their
            # ``initInputs`` whenever the task declares init-outputs.  Kept
            # as-is to preserve the fixture; confirm whether this is
            # intended.
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initRefs = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]
                initOutputs[taskDef] = initRefs
            inputDSType = _makeDatasetType(connections.input)
            outputDSType = _makeDatasetType(connections.output)
            for a, b in ((1, 2), (3, 4)):
                # One data ID is shared by the quantum and its input and
                # output references.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                inputRefs = [DatasetRef(inputDSType, dataId)]
                outputRefs = [DatasetRef(outputDSType, dataId)]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(
            quantumMap, metadata=METADATA, universe=universe, initInputs=initInputs, initOutputs=initOutputs
        )
        self.universe = universe

    def testTaskGraph(self):
        """Every task definition must appear in the task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum handed to the constructor must be in the graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Look up a node by its ID; an unknown ID raises `KeyError`."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip must compare equal to the original graph."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """Input quanta are those of the first (R) and fourth (U) tasks."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        """Output quanta are those of the third (T) and fourth (U) tasks."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph length is two quanta per task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """``getQuantaForTask`` returns exactly the quanta built in setUp."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        """Nodes returned for a task wrap exactly that task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """``Dummy1Output`` is consumed by the second task."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """``Dummy1Output`` is produced by the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is associated with the first two tasks."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """Lookup by task name returns the first task definition first."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Label ``R`` maps to the first task definition."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """Quanta touching ``Dummy1Input`` are those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` lists every non-InitOutput connection name."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that node's quantum and the build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        """Splitting the fixture yields four connected sub-graphs."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Membership summed over all sub-graphs must equal the size of the
        # full graph.
        count = sum(
            cg.checkQuantumInGraph(node.quantum) for node in self.qGraph for cg in connectedGraphs
        )
        self.assertEqual(len(self.qGraph), count)

        # Check the task set of every sub-graph individually; keying a dict
        # by task-graph size would keep only one graph per size.
        for cg in connectedGraphs:
            tskSet = set(cg.taskGraph)
            if len(tskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif len(tskSet) == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        """Outputs of the first task's nodes are the second task's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        """Connections of the second task span tasks one through three."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        """Ancestors of the second task's nodes: themselves plus task one."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        """The fixture graph has no cycles."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round-trip the graph through a file object.

        Covers a full load, a single-node load, restoration of the init
        input/output refs, and loading with a different dimension universe.
        """
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # Different universes.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            # A version-only difference is expected to load, with a log
            # message.
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # A different namespace is expected to be rejected.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        """Round-trip the graph through a local file URI.

        Covers full and partial loads, the ``graphID`` check, and the
        errors raised for an unknown node, a wrong graph ID and a bad file
        extension.
        """
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random in a single draw.
                nodeNumber, nodeNumber2 = random.sample([n.nodeId for n in self.qGraph], 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # delete=False above, so the file has to be removed by hand.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Round-trip the graph through a mocked S3 bucket."""
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        """A node obtained by iteration is reported as contained."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        """The save header records the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

498 

499 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`."""

502 

503 

def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Parameters
    ----------
    module : module
        The test module being set up; not used here.
    """
    lsst.utils.tests.init()

506 

507 

# Script entry point: initialise the LSST test utilities, then run the tests.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()