Coverage for tests/test_quantumGraph.py: 22%

309 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-09 03:06 -0800

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from itertools import chain 

29from typing import Iterable 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import QuantumNode 

44from lsst.utils.introspection import get_full_type_name 

45 

# boto3 and moto are optional: they are only needed for the S3 round-trip
# test, which is skipped (via ``unittest.skipIf(not boto3, ...)``) when
# either import fails and ``boto3`` is therefore set to None.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    boto3 = None

    def mock_s3(cls):
        """No-op stand-in for moto's ``mock_s3`` decorator.

        Returns ``cls`` unchanged so that decorated test code can still be
        defined when moto cannot be imported.
        """
        return cls


# Metadata attached to the graph built in ``QuantumGraphTestCase.setUp``;
# the URI save/load test checks that it survives a round trip.
METADATA = {"a": [1, 2, 3]}

58 

59 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first task (label R) of the R -> S -> T chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` (consumed by
    `Dummy2Connections`), both with dimensions (A, B). Its init-output
    ``Dummy1InitOutput`` is the init-input of `Dummy2Connections`.
    """

    initOutput = cT.InitOutput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Input", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))

64 

65 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config for `Dummy1PipelineTask`, wired to `Dummy1Connections`."""

    # Placeholder field; the tests never read it.
    conf1 = Field(dtype=int, default=1, doc="dummy config")

68 

69 

class Dummy1PipelineTask(PipelineTask):
    """Task with no run logic; only its config and connections are used."""

    ConfigClass = Dummy1Config

72 

73 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the middle task (label S) of the R -> S -> T chain.

    Consumes `Dummy1Connections`' ``Dummy1Output`` (and its init-output
    ``Dummy1InitOutput``). Produces ``Dummy2Output`` and
    ``Dummy2InitOutput``, which `Dummy3Connections` consumes.
    """

    initInput = cT.InitInput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))

79 

80 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config for `Dummy2PipelineTask`, wired to `Dummy2Connections`."""

    # Placeholder field; the tests never read it.
    conf1 = Field(dtype=int, default=1, doc="dummy config")

83 

84 

class Dummy2PipelineTask(PipelineTask):
    """Task with no run logic; only its config and connections are used."""

    ConfigClass = Dummy2Config

87 

88 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the last task (label T) of the R -> S -> T chain.

    Consumes `Dummy2Connections`' ``Dummy2Output`` and ``Dummy2InitOutput``.
    Its own ``Dummy3Output`` and ``Dummy3InitOutput`` are not consumed by
    any other dummy task.
    """

    initInput = cT.InitInput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy3InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy3Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))

94 

95 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Config for `Dummy3PipelineTask`, wired to `Dummy3Connections`."""

    # Placeholder field; the tests never read it.
    conf1 = Field(dtype=int, default=1, doc="dummy config")

98 

99 

class Dummy3PipelineTask(PipelineTask):
    """Task with no run logic; only its config and connections are used."""

    ConfigClass = Dummy3Config

102 

103 

# A task that shares no dataset types with the other dummy tasks, used to
# check that an isolated task is handled correctly in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the isolated task (label U).

    ``Dummy4Input`` and ``Dummy4Output`` are not used by any other dummy
    task, and there are no init-inputs or init-outputs.
    """

    input = cT.Input(name="Dummy4Input", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy4Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))

109 

110 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Config for `Dummy4PipelineTask`, wired to `Dummy4Connections`."""

    # Placeholder field; the tests never read it.
    conf1 = Field(dtype=int, default=1, doc="dummy config")

113 

114 

class Dummy4PipelineTask(PipelineTask):
    """Task with no run logic; only its config and connections are used."""

    ConfigClass = Dummy4Config

117 

118 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a `QuantumGraph` from four dummy tasks labelled R, S, T
    and U. Each task has two quanta, for data IDs ``{A: 1, B: 2}`` and
    ``{A: 3, B: 4}``. R -> S -> T are chained through their inputs and
    outputs. U reads and writes dataset types no other task uses, so its
    quanta are isolated.
    """

    def setUp(self):
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Connections without a ``dimensions`` attribute (presumably the
            # init connections) get empty dimensions.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of Quantum objects for that task.
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections

            # Init-input refs are kept in their own variable: they are what
            # each quantum of this task consumes at construction time.
            # Init-output refs are recorded only at the graph level.
            # (Sharing one variable previously made every quantum receive
            # the task's init-outputs as its init-inputs.)
            quantumInitInputRefs = None
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                quantumInitInputRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = quantumInitInputRefs
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputs[taskDef] = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]

            inputDSType = _makeDatasetType(connections.input)
            outputDSType = _makeDatasetType(connections.output)
            for a, b in ((1, 2), (3, 4)):
                # The quantum and its input/output refs share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=quantumInitInputRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(
            quantumMap, metadata=METADATA, universe=universe, initInputs=initInputs, initOutputs=initOutputs
        )
        self.universe = universe

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A freshly generated UUID cannot belong to any node in the graph.
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        # R starts the chain and U is isolated, so both are graph inputs.
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        # T ends the chain and U is isolated, so both are graph outputs.
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        # Every task has exactly two quanta.
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        # Dummy1Output is produced by R and consumed by S.
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # Expected: every connection name except the InitOutput connections
        # (init-outputs that are also some task's InitInput are still
        # included via that InitInput connection).
        truth = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections)
            for connection in conClass.allConnections.values()  # type: ignore
            if not isinstance(connection, cT.InitOutput)
        }
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        # A subset keeps the build ID of the graph it was taken from.
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        # Not connected: each of the two data IDs yields its own R -> S -> T
        # chain, and each of the two U quanta stands alone, giving four
        # connected components in total.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Summed over all components, membership hits must equal the number
        # of nodes in the full graph.
        count = sum(
            1
            for node in self.qGraph
            for cg in connectedGraphs
            if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        # Components are keyed by their task count; the 3-task chains hold
        # R, S, T and the 1-task components hold only U.
        taskSets = {len(cg.taskGraph): set(cg.taskGraph) for cg in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            # Within a chain, the second node's only input is the first node.
            inputs = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        # The ancestor set includes the queried node itself.
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            # Check that InitInput and InitOutput refs are restored correctly:
            # S and T have init-inputs; R, S and T have init-outputs.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # Different universes: a different version only logs...
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # ...while a different namespace is rejected outright.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random.
                nodeNumber, nodeNumber2 = random.sample([n.nodeId for n in self.qGraph], 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non-existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # delete=False above, so the temporary file must be removed here.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from a mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri)
        self.assertEqual(self.qGraph, restore)
        nodeId = next(iter(self.qGraph)).nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

502 

503 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Runs the checks inherited from ``lsst.utils.tests.MemoryTestCase``.

    Nothing is added or overridden here.
    """

    pass

506 

507 

def setup_module(module):
    """Module-level setup hook (named per pytest's ``setup_module`` convention).

    Calls ``lsst.utils.tests.init()``; ``module`` is not used.
    """
    lsst.utils.tests.init()

510 

511 

# Allow running this file directly with the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()