Coverage for tests/test_quantumGraph.py: 20%

312 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-22 02:08 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from itertools import chain 

29from typing import Iterable 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import QuantumNode 

44from lsst.utils.introspection import get_full_type_name 

45 

# Graph-level metadata passed to QuantumGraph in setUp and compared again
# after a saveUri/loadUri round trip.
METADATA = dict(a=[1, 2, 3])

47 

48 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the head task of the chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` plus the
    ``Dummy1InitOutput`` init-output that ``Dummy2Connections`` consumes.
    """

    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy1InitOutput")
    input = cT.Input(doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy1Input")
    output = cT.Output(doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy1Output")


class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config for ``Dummy1PipelineTask`` with a single placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy1PipelineTask(PipelineTask):
    """Placeholder task; only its config/connections are used by the tests."""

    ConfigClass = Dummy1Config

61 

62 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the middle task of the chain.

    Consumes ``Dummy1InitOutput`` / ``Dummy1Output`` (produced by the Dummy1
    connections) and produces ``Dummy2InitOutput`` / ``Dummy2Output``.
    """

    initInput = cT.InitInput(doc="n/a", storageClass="ExposureF", name="Dummy1InitOutput")
    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy2InitOutput")
    input = cT.Input(doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy1Output")
    output = cT.Output(doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy2Output")


class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config for ``Dummy2PipelineTask`` with a single placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy2PipelineTask(PipelineTask):
    """Placeholder task; only its config/connections are used by the tests."""

    ConfigClass = Dummy2Config

76 

77 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the tail task of the chain.

    Consumes ``Dummy2InitOutput`` / ``Dummy2Output`` and produces
    ``Dummy3InitOutput`` / ``Dummy3Output``, which no other task reads.
    """

    initInput = cT.InitInput(doc="n/a", storageClass="ExposureF", name="Dummy2InitOutput")
    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy3InitOutput")
    input = cT.Input(doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy2Output")
    output = cT.Output(doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy3Output")


class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Config for ``Dummy3PipelineTask`` with a single placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy3PipelineTask(PipelineTask):
    """Placeholder task; only its config/connections are used by the tests."""

    ConfigClass = Dummy3Config

91 

92 

# A task that shares no datasets with the other tasks, to check that an
# isolated task is handled correctly by the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the standalone task (no init-inputs or init-outputs)."""

    input = cT.Input(doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy4Input")
    output = cT.Output(doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy4Output")


class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Config for ``Dummy4PipelineTask`` with a single placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy4PipelineTask(PipelineTask):
    """Placeholder task; only its config/connections are used by the tests."""

    ConfigClass = Dummy4Config

106 

107 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from four tasks labelled R, S, T and U over a
    universe with two dimensions, A and B. Each task gets two quanta, one for
    data ID (A=1, B=2) and one for (A=3, B=4). R -> S -> T are chained through
    their ``Dummy*Output`` datasets, while U shares no datasets with the
    others.
    """

    def setUp(self):
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Connections without a ``dimensions`` attribute (presumably the
            # init-input/init-output connections) get an empty dimension set.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of quanta generated for that task.
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        dataset_types = set()
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = initRefs
                dataset_types.add(initInputDSType)
            else:
                initRefs = None
            # NOTE(review): ``initRefs`` is rebound just below. For every task
            # that has an init-output, the quanta therefore receive that
            # init-output ref as their ``initInputs``. R's quanta get one even
            # though R declares no init-inputs. The expected counts in
            # ``testSubset`` may rely on this, so confirm before separating
            # the two lists.
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initRefs = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]
                initOutputs[taskDef] = initRefs
                dataset_types.add(initOutputDSType)
            inputDSType = _makeDatasetType(connections.input)
            dataset_types.add(inputDSType)
            outputDSType = _makeDatasetType(connections.output)
            dataset_types.add(outputDSType)
            for a, b in ((1, 2), (3, 4)):
                # One data ID shared by the quantum and its input/output refs.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        dataset_types.add(self.packagesDSType)
        globalInitOutputs = [DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe))]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=dataset_types,
        )
        self.universe = universe
        self.num_dataset_types = len(dataset_types)

    def testTaskGraph(self):
        """Every TaskDef used to build the graph appears in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every input quantum is present among the graph's nodes."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Nodes are retrievable by ID; an unknown ID raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip yields an equal graph."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """Input quanta are those of R and U, which read no produced datasets."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self):
        """Output quanta are those of T and U, whose outputs nobody reads."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph holds two quanta per task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """Quanta looked up per task match those used to build the graph."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self):
        """Per-task quantum counts match those used to build the graph."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self):
        """The quanta wrapped by each task's nodes match the input quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """``Dummy1Output`` is consumed by S."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """``Dummy1Output`` is produced by R."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is touched by R (producer) and S (consumer)."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """A TaskDef can be found from its task's qualified name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """A TaskDef can be found from its label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """Quanta touching ``Dummy1Input`` are exactly R's quanta."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` covers every connection name except init-outputs."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A single-node subset keeps build ID, global init-outputs and a
        reduced set of registry dataset types."""
        allNodes = list(self.qGraph)
        firstNode = allNodes[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(firstNode.quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)
        # Depending on which task was first the list can contain a different
        # number of datasets. The first task can be either Dummy1 or Dummy4.
        num_types = {"R": 4, "U": 3}
        self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

    def testSubsetToConnected(self):
        """Splitting into connected components yields the expected pieces."""
        # Not connected: the R -> S -> T chain exists once per data ID, and
        # each U quantum is isolated, giving four components in total.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # U (tasks[3]) is expected to be on its own; every other component
        # is a three-quantum R -> S -> T chain.
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Together the components must account for every quantum of the
        # full graph: total memberships equal the number of nodes.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        for cg in connectedGraphs:
            taskSet = set(cg.taskGraph)
            if len(taskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), taskSet)
            elif len(taskSet) == 1:
                self.assertEqual({self.tasks[-1]}, taskSet)

        # These checks rely on iterating a graph in dependency order, so in
        # each chain the only input of every node is the node before it.
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            self.assertEqual({allNodes[0]}, cg.determineInputsToQuantumNode(allNodes[1]))
            self.assertEqual({allNodes[1]}, cg.determineInputsToQuantumNode(allNodes[2]))

    def testDetermineOutputsOfQuantumNode(self):
        """The downstream nodes of R's quanta are exactly S's quanta."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        """The connected nodes of S's quanta are R, S and T's quanta."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # The returned connected subgraph includes the queried nodes too.
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        """The ancestors of S's quanta are R's quanta plus the queried nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        """The test graph has no cycles."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Save/load round trip through a file object, full and partial,
        plus universe-compatibility checks."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node.
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # A universe with only a different version still loads, but logs.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # A universe with a different namespace is rejected.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        """Save/load round trip through a URI, including node subsets and
        error cases."""
        # A temporary directory is cleaned up automatically and avoids
        # holding an open handle on the file while saveUri writes it.
        with tempfile.TemporaryDirectory() as tmpDir:
            uri = os.path.join(tmpDir, "test.qgraph")
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri)
            self.assertEqual(restore.metadata, METADATA)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct nodes at random.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)

            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

            # Verify that loading more than one node works.
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )

            # Verify an error when requesting a node that is not in the graph
            # (an int, which cannot match any UUID node ID).
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # Verify a graphID that does not match is an error.
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testSaveLoadNoRegistryDatasetTypes(self):
        """Test for reading quantum that is missing registry dataset types.

        This test depends on internals of QuantumGraph implementation, in
        particular that empty list of registry dataset types is not stored,
        which makes save file identical to the "old" format.
        """
        # Reset the list, this is safe as QuantumGraph itself does not use it.
        self.qGraph._registryDatasetTypes = []
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            self.assertEqual(restore.registryDatasetTypes(), [])

    def testContains(self):
        """A node obtained by iterating the graph is reported as contained."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        """The save header records the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

516 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit the checks of ``lsst.utils.tests.MemoryTestCase``; adds none of its own."""

519 

520 

def setup_module(module):
    """Module setup hook (pytest ``setup_module`` convention).

    The ``module`` argument is unused; the hook only initializes
    ``lsst.utils.tests``.
    """
    lsst.utils.tests.init()

523 

524 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities first, then
    # hand control to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()