Coverage for tests/test_quantumGraph.py: 19%

359 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-11 03:31 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28import os 

29import pickle 

30import random 

31import tempfile 

32import unittest 

33import uuid 

34from itertools import chain 

35 

36import lsst.pipe.base.connectionTypes as cT 

37import lsst.utils.tests 

38from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

39from lsst.pex.config import Field 

40from lsst.pipe.base import ( 

41 DatasetTypeName, 

42 PipelineTask, 

43 PipelineTaskConfig, 

44 PipelineTaskConnections, 

45 QuantumGraph, 

46 TaskDef, 

47) 

48from lsst.pipe.base.graph.quantumNode import BuildId 

49from lsst.pipe.base.tests.util import check_output_run, get_output_refs 

50from lsst.utils.introspection import get_full_type_name 

51from lsst.utils.packages import Packages 

52 

# Metadata mapping handed to the QuantumGraph built in QuantumGraphTestCase.setUp.
METADATA = dict(a=[1, 2, 3])

54 

55 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first dummy task.

    Declares one init-output plus one input and one output, both with
    dimensions ``(A, B)``; every dataset uses storage class ``ExposureF``.
    The init-output and output are consumed by `Dummy2Connections`.
    """

    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy1InitOutput")
    input = cT.Input(
        doc="n/a",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        name="Dummy1Input",
    )
    output = cT.Output(
        doc="n/a",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        name="Dummy1Output",
    )

62 

63 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

68 

69 

class Dummy1PipelineTask(PipelineTask):
    """First dummy pipeline task.

    Defines no processing of its own; the tests only use its configuration
    class (and therefore `Dummy1Connections`).
    """

    ConfigClass = Dummy1Config

74 

75 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second dummy task.

    Consumes the init-output and output of `Dummy1Connections` and produces
    its own init-output and ``(A, B)`` output, which `Dummy3Connections`
    consumes in turn.
    """

    initInput = cT.InitInput(doc="n/a", storageClass="ExposureF", name="Dummy1InitOutput")
    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy2InitOutput")
    input = cT.Input(
        doc="n/a",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        name="Dummy1Output",
    )
    output = cT.Output(
        doc="n/a",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        name="Dummy2Output",
    )

83 

84 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

89 

90 

class Dummy2PipelineTask(PipelineTask):
    """Dummy pipeline task #2.

    Defines no processing of its own; the tests only use its configuration
    class (and therefore `Dummy2Connections`).
    """

    ConfigClass = Dummy2Config

95 

96 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third dummy task.

    Consumes the init-output and output of `Dummy2Connections` and produces
    its own init-output and ``(A, B)`` output, which no other dummy task uses.
    """

    initInput = cT.InitInput(doc="n/a", storageClass="ExposureF", name="Dummy2InitOutput")
    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy3InitOutput")
    input = cT.Input(
        doc="n/a",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        name="Dummy2Output",
    )
    output = cT.Output(
        doc="n/a",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        name="Dummy3Output",
    )

104 

105 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

110 

111 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy pipeline task.

    Defines no processing of its own; the tests only use its configuration
    class (and therefore `Dummy3Connections`).
    """

    ConfigClass = Dummy3Config

116 

117 

# Connections for a task that shares no datasets with the other dummy tasks,
# used to check that such an isolated task is handled correctly in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the fourth, stand-alone dummy task.

    Has no init-inputs or init-outputs; its ``(A, B)`` input and output are
    not produced or consumed by any other dummy task.
    """

    input = cT.Input(
        doc="n/a",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        name="Dummy4Input",
    )
    output = cT.Output(
        doc="n/a",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        name="Dummy4Output",
    )

125 

126 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for `Dummy4PipelineTask`, bound to `Dummy4Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

131 

132 

class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy pipeline task, isolated from the others.

    Defines no processing of its own; the tests only use its configuration
    class (and therefore `Dummy4Connections`).
    """

    ConfigClass = Dummy4Config

137 

138 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph."""

    input_collection = "inputs"
    output_run = "run"

    def setUp(self) -> None:
        """Build the small QuantumGraph shared by every test.

        Four tasks are added with labels ``R``, ``S``, ``T`` and ``U``
        (Dummy1..Dummy4). ``R``, ``S`` and ``T`` are chained through their
        outputs and init-outputs; ``U`` shares no datasets with them. Each
        task gets one quantum per data ID ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``.
        """
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Init connections carry no ``dimensions`` attribute; use () for them.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of quanta for that task.
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        dataset_types = set()
        # Refs already produced by earlier tasks, so downstream tasks consume
        # the exact same DatasetRef objects (same dataset IDs) as inputs.
        init_dataset_refs: dict[DatasetType, DatasetRef] = {}
        dataset_refs: dict[tuple[DatasetType, DataCoordinate], DatasetRef] = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                if (ref := init_dataset_refs.get(initInputDSType)) is not None:
                    initRefs = [ref]
                else:
                    initRefs = [
                        DatasetRef(
                            initInputDSType,
                            DataCoordinate.make_empty(universe),
                            run=self.input_collection,
                        )
                    ]
                initInputs[taskDef] = initRefs
                dataset_types.add(initInputDSType)
            else:
                initRefs = None
            if connections.initOutputs:
                # NOTE(review): this rebinds ``initRefs``, so each Quantum of a
                # task that has an init-output receives its init-OUTPUT refs as
                # ``initInputs`` below. That looks unintended, but other
                # assertions (e.g. the dataset-type counts in testSubset) may
                # depend on it -- confirm before changing.
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initRefs = [
                    DatasetRef(initOutputDSType, DataCoordinate.make_empty(universe), run=self.output_run)
                ]
                init_dataset_refs[initOutputDSType] = initRefs[0]
                initOutputs[taskDef] = initRefs
                dataset_types.add(initOutputDSType)
            inputDSType = _makeDatasetType(connections.input)
            dataset_types.add(inputDSType)
            outputDSType = _makeDatasetType(connections.output)
            dataset_types.add(outputDSType)
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                # Reuse an upstream task's output ref when there is one, so the
                # consumer quantum reads exactly what the producer writes.
                if (ref := dataset_refs.get((inputDSType, dataId))) is None:
                    inputRefs = [DatasetRef(inputDSType, dataId, run=self.input_collection)]
                else:
                    inputRefs = [ref]
                outputRefs = [DatasetRef(outputDSType, dataId, run=self.output_run)]
                dataset_refs[(outputDSType, dataId)] = outputRefs[0]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quanta_by_task_label = {task_def.label: set(quanta) for task_def, quanta in quantumMap.items()}
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        dataset_types.add(self.packagesDSType)
        globalInitOutputs = [
            DatasetRef(self.packagesDSType, DataCoordinate.make_empty(universe), run=self.output_run)
        ]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=dataset_types,
        )
        self.universe = universe
        self.num_dataset_types = len(dataset_types)

    def testTaskGraph(self) -> None:
        """Every TaskDef is in the task graph and the pipeline graph labels match."""
        task_def_labels: list[str] = []
        for taskDef in self.quantumMap:
            task_def_labels.append(taskDef.label)
            self.assertIn(taskDef, self.qGraph.taskGraph)
        self.assertCountEqual(task_def_labels, self.qGraph.pipeline_graph.tasks.keys())

    def testGraph(self) -> None:
        """Every input quantum appears as a node in the underlying graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self) -> None:
        """Nodes can be fetched by ID; an unknown ID raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self) -> None:
        """The graph survives a pickle round trip unchanged."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self) -> None:
        """Input quanta are those of the chain head (R) and the isolated task (U)."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self) -> None:
        """Output quanta are those of the chain tail (T) and the isolated task (U)."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self) -> None:
        """The graph holds two quanta per task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self) -> None:
        """getQuantaForTask returns exactly the quanta built for each task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self) -> None:
        """getNumberOfQuantaForTask matches the number of quanta built per task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self) -> None:
        """Nodes for a task wrap that task's quanta and reference its task node."""
        for task in self.tasks:
            nodes = list(self.qGraph.getNodesForTask(task))
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])
            for node in nodes:
                self.assertEqual(
                    node.task_node.task_class_name,
                    self.qGraph.pipeline_graph.tasks[node.task_node.label].task_class_name,
                )

    def testFindTasksWithInput(self) -> None:
        """Dummy1Output is consumed by the second task."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self) -> None:
        """Dummy1Output is produced by the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self) -> None:
        """Dummy1Output is referenced by both its producer and its consumer."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self) -> None:
        """A TaskDef can be looked up by task class name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self) -> None:
        """A TaskDef can be looked up by label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self) -> None:
        """Quanta reading Dummy1Input are exactly the first task's quanta."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self) -> None:
        """allDatasetTypes lists every connection name except init-outputs."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self) -> None:
        """A single-node subset keeps the build ID, global init-outputs and
        only the registry dataset types relevant to that node's task.
        """
        allNodes = list(self.qGraph)
        firstNode = allNodes[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(firstNode.quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)
        # Depending on which task was first the list can contain different
        # number of datasets. The first task can be either Dummy1 or Dummy4.
        num_types = {"R": 4, "U": 3}
        self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

    def testSubsetToConnected(self) -> None:
        """Splitting into connected components yields two R->S->T chains
        (one per data ID) plus one component per isolated U quantum.
        """
        # Not connected: each data ID has its own R -> S -> T chain and each
        # U quantum is disconnected from everything else.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Each quantum should be found in exactly one component, so the total
        # number of (quantum, component) memberships equals the graph size.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            # The second node's only upstream neighbor is the first node...
            nodes = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, nodes)
            # ...and, within a chain, the first node's only downstream
            # neighbor is the second node.
            nodes = cg.determineOutputsOfQuantumNode(allNodes[0])
            self.assertEqual({allNodes[1]}, set(nodes))

    def testDetermineOutputsOfQuantumNode(self) -> None:
        """The downstream neighbors of R's nodes are exactly S's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self) -> None:
        """The connected subgraph around S's nodes covers R, S and T nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # The result also contains the queried nodes themselves.
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self) -> None:
        """The ancestors of S's nodes (including themselves) are R's and S's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self) -> None:
        """The test graph is acyclic."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self) -> None:
        """Round-trip through a file object, fully and for a single node, and
        check how loading behaves with a different dimension universe.
        """
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)

            # Different universes.
            # Same namespace, different version: expected to log at INFO, not fail.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # Different namespace: expected to be rejected.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self) -> None:
        """Round-trip through a file URI, including partial loads by node ID
        and the error cases for unknown node IDs and mismatched graph IDs.
        """
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, self.qGraph.metadata)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random for the partial loads.
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non-existent node ID
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(uuid.uuid4(),))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID=BuildId("NOTRIGHT"))
        finally:
            if uri is not None:
                os.remove(uri)

        # A URI without a ".qgraph" extension is expected to raise TypeError.
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testSaveLoadNoRegistryDatasetTypes(self) -> None:
        """Test for reading quantum that is missing registry dataset types.

        This test depends on internals of QuantumGraph implementation, in
        particular that empty list of registry dataset types is not stored,
        which makes save file identical to the "old" format.
        """
        # Reset the list, this is safe as QuantumGraph itself does not use it.
        self.qGraph._registryDatasetTypes = []
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            self.assertEqual(restore.registryDatasetTypes(), [])

    def testContains(self) -> None:
        """A node taken from the graph is reported as contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self) -> None:
        """The save header embeds the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

    def testUpdateRun(self) -> None:
        """Test for QuantumGraph.updateRun method."""
        self.assertEqual(check_output_run(self.qGraph, self.output_run), [])
        output_refs = get_output_refs(self.qGraph)
        self.assertGreater(len(output_refs), 0)
        graph_id = self.qGraph.graphID

        self.qGraph.updateRun("updated-run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        output_refs2 = get_output_refs(self.qGraph)
        self.assertEqual(len(output_refs2), len(output_refs))
        # All output dataset IDs must be updated.
        old_ids = {ref.id for ref in output_refs}
        new_ids = {ref.id for ref in output_refs2}
        self.assertTrue(old_ids.isdisjoint(new_ids))

        # Also update metadata.
        self.qGraph.updateRun("updated-run2", metadata_key="output_run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run2"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        assert self.qGraph.metadata is not None
        self.assertIn("output_run", self.qGraph.metadata)
        self.assertEqual(self.qGraph.metadata["output_run"], "updated-run2")

        # Update graph ID.
        self.qGraph.updateRun("updated-run3", metadata_key="output_run", update_graph_id=True)
        self.assertEqual(check_output_run(self.qGraph, "updated-run3"), [])
        self.assertNotEqual(self.qGraph.graphID, graph_id)

    def testMetadataPackage(self) -> None:
        """Test package versions added to QuantumGraph metadata."""
        packages = Packages.fromSystem()
        assert self.qGraph.metadata is not None
        self.assertFalse(self.qGraph.metadata["packages"].difference(packages))

    def test_get_task_quanta(self) -> None:
        """get_task_quanta returns the quanta built for each task label."""
        for task_label in self.qGraph.pipeline_graph.tasks.keys():
            quanta = self.qGraph.get_task_quanta(task_label)
            self.assertCountEqual(quanta.values(), self.quanta_by_task_label[task_label])

    def testGetSummary(self) -> None:
        """Test for QuantumGraph.getSummary method."""
        summary = self.qGraph.getSummary()
        self.assertEqual(self.qGraph.graphID, summary.graphID)
        self.assertEqual(len(summary.qgraphTaskSummaries), len(self.qGraph.taskGraph))

619 

620 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Apply the standard lsst.utils file-leak checks to this test module."""

623 

624 

def setup_module(module) -> None:
    """Module-level pytest hook: initialize lsst.utils test support.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

628 

629 

if __name__ == "__main__":
    # Run directly as a script: initialize lsst.utils test support, then
    # let unittest discover and run the test cases in this module.
    lsst.utils.tests.init()
    unittest.main()