Coverage for tests/test_quantumGraph.py: 29%

288 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-04 09:17 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from itertools import chain 

29from typing import Iterable 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import QuantumNode 

44from lsst.utils.introspection import get_full_type_name 

45 

try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Missing either boto3 or moto disables S3 support entirely: boto3 is
    # reset to None even when only moto failed to import.
    boto3 = None

    def mock_s3(cls):
        """Stand-in for ``moto.mock_s3`` used when moto is unavailable.

        Returns the decorated object unchanged, so code decorated with it
        still imports cleanly.
        """
        return cls


# Metadata attached to the test quantum graph; compared again after a
# save/load round trip.
METADATA = dict(a=[1, 2, 3])

58 

59 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first task in the chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output``; the init-output
    ``Dummy1InitOutput`` is consumed by ``Dummy2Connections``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

64 

65 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for ``Dummy1PipelineTask``; ``conf1`` is a placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

68 

69 

class Dummy1PipelineTask(PipelineTask):
    """Dummy task used to populate the test quantum graph; it defines no run method."""

    ConfigClass = Dummy1Config

72 

73 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the middle task in the chain.

    Consumes ``Dummy1InitOutput`` and ``Dummy1Output`` (produced by the first
    task) and produces ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

79 

80 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for ``Dummy2PipelineTask``; ``conf1`` is a placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

83 

84 

class Dummy2PipelineTask(PipelineTask):
    """Dummy task used to populate the test quantum graph; it defines no run method."""

    ConfigClass = Dummy2Config

87 

88 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the last task in the chain.

    Consumes ``Dummy2InitOutput`` and ``Dummy2Output`` and produces
    ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

94 

95 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for ``Dummy3PipelineTask``; ``conf1`` is a placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

98 

99 

class Dummy3PipelineTask(PipelineTask):
    """Dummy task used to populate the test quantum graph; it defines no run method."""

    ConfigClass = Dummy3Config

102 

103 

# Dummy4 shares no datasets with the other dummy tasks; it checks that a task
# which does not interact with the others still works fine in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of a standalone task: ``Dummy4Input`` -> ``Dummy4Output``."""

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

109 

110 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for ``Dummy4PipelineTask``; ``conf1`` is a placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

113 

114 

class Dummy4PipelineTask(PipelineTask):
    """Dummy task used to populate the test quantum graph; it defines no run method."""

    ConfigClass = Dummy4Config

117 

118 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from the four dummy tasks of this module,
    labelled "R", "S", "T" and "U". Each task gets one quantum for each of
    the data IDs ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``.
    """

    def setUp(self):
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                # "A" and "B" have identical definitions; the comprehension
                # builds an independent dict for each so they never alias.
                "elements": {
                    name: {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    }
                    for name in ("A", "B")
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)
        # Data IDs shared by every task. Each one is used both as the quantum
        # data ID and as the data ID of that quantum's input/output refs.
        dataIds = tuple(
            DataCoordinate.standardize({"A": a, "B": b}, universe=universe) for a, b in ((1, 2), (3, 4))
        )
        # Mapping of TaskDef to the set of quanta that task should own.
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            # Dataset types depend only on the task's connections, not on the
            # data ID, so build them once per task.
            if connections.initInputs:
                initInputDSType = DatasetType(
                    connections.initInput.name,
                    (),
                    storageClass=connections.initInput.storageClass,
                    universe=universe,
                )
            else:
                initInputDSType = None
            inputDSType = DatasetType(
                connections.input.name,
                connections.input.dimensions,
                storageClass=connections.input.storageClass,
                universe=universe,
            )
            outputDSType = DatasetType(
                connections.output.name,
                connections.output.dimensions,
                storageClass=connections.output.storageClass,
                universe=universe,
            )
            quantumSet = set()
            for dataId in dataIds:
                # Fresh ref lists for every quantum, as in a real graph build.
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA, universe=universe)
        self.universe = universe

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A freshly generated UUID cannot belong to any node in the graph.
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        # Inputs are the head of the R -> S -> T chain plus the standalone U.
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        # Outputs are the tail of the R -> S -> T chain plus the standalone U.
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        # Each task contributes one quantum per data ID, i.e. two quanta.
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # Every connection name except InitOutputs is expected.
        truth = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections)
            for connection in conClass.allConnections.values()  # type: ignore
            if not isinstance(connection, cT.InitOutput)
        }
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        # Not connected: the two data IDs give two independent R -> S -> T
        # chains, and task U, which shares no datasets with the others,
        # adds one isolated quantum per data ID.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)
            # task[3] is expected to be on its own; the others form a chain.
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # The connected subgraphs must partition the original graph: every
        # quantum belongs to exactly one of them.
        for node in self.qGraph:
            memberships = sum(cg.checkQuantumInGraph(node.quantum) for cg in connectedGraphs)
            self.assertEqual(memberships, 1)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            inputs = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

            # Same namespace, different version: loading succeeds and logs a
            # message from lsst.daf.butler.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # Different namespace: loading must be rejected.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        # A temporary directory (rather than a still-open NamedTemporaryFile)
        # lets saveUri create the file itself and is cleaned up automatically.
        with tempfile.TemporaryDirectory() as tmpDir:
            uri = os.path.join(tmpDir, "test.qgraph")
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri)
            self.assertEqual(restore.metadata, METADATA)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct nodes at random to load back individually.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)

            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )

            # verify an error when requesting a non-existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from a mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit the checks of ``lsst.utils.tests.MemoryTestCase`` unchanged."""

489 

490 

def setup_module(module):
    """Module setup hook (pytest ``setup_module`` convention).

    Calls ``lsst.utils.tests.init()``; the ``module`` argument is not used.
    """
    lsst.utils.tests.init()

493 

494 

# Allow running this test module directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()