Coverage for tests/test_quantumGraph.py: 21%

293 statements  

Report generated by coverage.py v6.5.0 at 2023-02-23 03:02 -0800.

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from itertools import chain 

29from typing import Iterable 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import QuantumNode 

44from lsst.utils.introspection import get_full_type_name 

45 

# Metadata attached to the fixture graph; testSaveLoadUri checks that it
# survives a saveUri/loadUri round trip.
METADATA = dict(a=[1, 2, 3])

47 

48 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first task of the Dummy1 -> Dummy2 -> Dummy3 chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output``; also declares the
    ``Dummy1InitOutput`` init-output that `Dummy2Connections` reads.
    """

    initOutput = cT.InitOutput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

53 

54 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, wired to `Dummy1Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

57 

58 

class Dummy1PipelineTask(PipelineTask):
    """Placeholder task using `Dummy1Config`; it defines no processing logic."""

    ConfigClass = Dummy1Config

61 

62 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the middle task of the Dummy1 -> Dummy2 -> Dummy3 chain.

    Consumes ``Dummy1Output`` and the ``Dummy1InitOutput`` init dataset,
    and produces ``Dummy2Output`` plus the ``Dummy2InitOutput`` init dataset.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

68 

69 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, wired to `Dummy2Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

72 

73 

class Dummy2PipelineTask(PipelineTask):
    """Placeholder task using `Dummy2Config`; it defines no processing logic."""

    ConfigClass = Dummy2Config

76 

77 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the last task of the Dummy1 -> Dummy2 -> Dummy3 chain.

    Consumes ``Dummy2Output`` and the ``Dummy2InitOutput`` init dataset,
    and produces ``Dummy3Output`` plus the ``Dummy3InitOutput`` init dataset.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy3InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

83 

84 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, wired to `Dummy3Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

87 

88 

class Dummy3PipelineTask(PipelineTask):
    """Placeholder task using `Dummy3Config`; it defines no processing logic."""

    ConfigClass = Dummy3Config

91 

92 

# Dummy4 shares no dataset types with the other Dummy tasks, so it checks
# that a Task that does not interact with the rest still works in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for an isolated task with no init-input or init-output.

    ``Dummy4Input`` and ``Dummy4Output`` are not used by any other Dummy task.
    """

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

98 

99 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for `Dummy4PipelineTask`, wired to `Dummy4Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

102 

103 

class Dummy4PipelineTask(PipelineTask):
    """Placeholder task using `Dummy4Config`; it defines no processing logic."""

    ConfigClass = Dummy4Config

106 

107 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    The fixture graph built in `setUp` holds four tasks (labels R, S, T, U),
    each with two quanta: one for data ID ``A=1, B=2`` and one for
    ``A=3, B=4``. R, S and T are chained through their ``DummyNOutput``
    datasets. U shares no dataset types with the other tasks.
    """

    def setUp(self):
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            """Build a `DatasetType` from a connection's name, dimensions
            (empty for init connections, which have none) and storage class.
            """
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            # Init-input and init-output refs are kept in separate variables:
            # only the init-inputs belong on each Quantum. (Reusing one name
            # used to hand the task's init *outputs* to its quanta.)
            initInputRefs = None
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                initInputRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = initInputRefs
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputs[taskDef] = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]
            inputDSType = _makeDatasetType(connections.input)
            outputDSType = _makeDatasetType(connections.output)
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        globalInitOutputs = [DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe))]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
        )
        self.universe = universe

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # The expected names are every non-InitOutput connection name.
        truth = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections)
            for connection in conClass.allConnections.values()  # type: ignore
            if not isinstance(connection, cT.InitOutput)
        }
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)

    def testSubsetToConnected(self):
        # Not connected: each of the two data IDs has its own R -> S -> T
        # chain, and task U (which shares no dataset types) contributes one
        # isolated quantum per data ID, giving four components in total.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Each quantum of the full graph must land in exactly one component,
        # so the total number of (quantum, component) hits equals the size.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        self.assertEqual(set(taskSets), {1, 3})
        self.assertEqual(set(self.tasks[:-1]), taskSets[3])
        self.assertEqual({self.tasks[-1]}, taskSets[1])

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            # Iteration yields the R, S, T nodes of this chain in dependency
            # order, so each node's only input is the node before it.
            allNodes = list(cg)
            self.assertEqual({allNodes[0]}, cg.determineInputsToQuantumNode(allNodes[1]))
            self.assertEqual({allNodes[1]}, cg.determineInputsToQuantumNode(allNodes[2]))

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # Different universes.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random; a fixed seed keeps any
                # failure reproducible from run to run.
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.Random(20230223).sample(nodeIds, 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # delete=False above, so the file must be removed explicitly.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

485 

486 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Hook this module into `lsst.utils.tests.MemoryTestCase`.

    It adds no tests of its own; all checks come from the base class.
    """

489 

490 

def setup_module(module):
    """Module-level setup hook (xunit style, as called by pytest).

    Calls ``lsst.utils.tests.init()`` once for this module. The ``module``
    argument is required by the hook signature but is not used.
    """
    lsst.utils.tests.init()

493 

494 

# Direct execution: initialise lsst.utils.tests, then run the unittest suite.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()