Coverage for tests/test_quantumGraph.py: 19%

342 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-08-23 10:31 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from collections.abc import Iterable 

29from itertools import chain 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import BuildId, QuantumNode 

44from lsst.pipe.base.tests.util import check_output_run, get_output_refs 

45from lsst.utils.introspection import get_full_type_name 

46 

# Metadata attached to the fixture graph; testSaveLoadUri checks that it
# survives a saveUri/loadUri round trip unchanged.
METADATA = dict(a=[1, 2, 3])

48 

49 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` (both per ``A``/``B``
    data ID), and declares the ``Dummy1InitOutput`` init-output.
    """

    initOutput = cT.InitOutput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

56 

57 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration of the first dummy task, bound to `Dummy1Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

62 

63 

class Dummy1PipelineTask(PipelineTask):
    """First dummy pipeline task; only declares its configuration class."""

    ConfigClass = Dummy1Config

68 

69 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Consumes what the first dummy task produces (``Dummy1InitOutput`` as an
    init-input and ``Dummy1Output`` as a regular input) and produces
    ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

77 

78 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration of the second dummy task, bound to `Dummy2Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

83 

84 

class Dummy2PipelineTask(PipelineTask):
    """Dummy pipeline task #2."""

    # Only the configuration class is declared; the tests never run the task.
    ConfigClass = Dummy2Config

89 

90 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Consumes ``Dummy2InitOutput`` (init-input) and ``Dummy2Output`` from the
    second dummy task and produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy3InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

98 

99 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration of the third dummy task, bound to `Dummy3Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

104 

105 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy pipeline task; only declares its configuration class."""

    ConfigClass = Dummy3Config

110 

111 

# A task sharing no datasets with the other dummy tasks, used to check that
# an isolated task is handled correctly in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the fourth, standalone dummy task.

    Reads ``Dummy4Input`` and writes ``Dummy4Output``; it has no init
    connections.
    """

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

119 

120 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration of the fourth dummy task, bound to `Dummy4Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

125 

126 

class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy pipeline task; only declares its configuration class."""

    ConfigClass = Dummy4Config

131 

132 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    The fixture built in `setUp` contains four tasks labelled ``R``, ``S``,
    ``T`` and ``U`` (Dummy1..Dummy4), each with one quantum for each of the
    data IDs ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``. ``R -> S -> T`` are
    chained through their regular and init datasets; ``U`` shares no
    datasets with the others.
    """

    input_collection = "inputs"
    output_run = "run"

    def setUp(self) -> None:
        # Minimal dimension universe with two integer-keyed dimensions "A" and
        # "B". testSaveLoad varies "version" and "namespace" of this config.
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            """Build a `DatasetType` in ``universe`` for a task connection."""
            return DatasetType(
                connection.name,
                # Fall back to no dimensions for connections without a
                # ``dimensions`` attribute (presumably the init connections).
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of quanta for that task.
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        dataset_types = set()
        # Init-output refs keyed by dataset type, so that a downstream task's
        # init-input reuses the exact ref its upstream task produces.
        init_dataset_refs: dict[DatasetType, DatasetRef] = {}
        # Output refs keyed by (dataset type, data ID), for the same reason.
        dataset_refs: dict[tuple[DatasetType, DataCoordinate], DatasetRef] = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections

            # Init-inputs and init-outputs are kept in separate variables so
            # that each Quantum receives its task's init-*inputs* only.
            initInputRefs = None
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                if (ref := init_dataset_refs.get(initInputDSType)) is not None:
                    initInputRefs = [ref]
                else:
                    initInputRefs = [
                        DatasetRef(
                            initInputDSType,
                            DataCoordinate.makeEmpty(universe),
                            run=self.input_collection,
                        )
                    ]
                initInputs[taskDef] = initInputRefs
                dataset_types.add(initInputDSType)
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputRefs = [
                    DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe), run=self.output_run)
                ]
                init_dataset_refs[initOutputDSType] = initOutputRefs[0]
                initOutputs[taskDef] = initOutputRefs
                dataset_types.add(initOutputDSType)

            inputDSType = _makeDatasetType(connections.input)
            dataset_types.add(inputDSType)
            outputDSType = _makeDatasetType(connections.output)
            dataset_types.add(outputDSType)
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                # Reuse the upstream task's output ref when there is one so
                # the quanta get connected in the graph.
                if (ref := dataset_refs.get((inputDSType, dataId))) is None:
                    inputRefs = [DatasetRef(inputDSType, dataId, run=self.input_collection)]
                else:
                    inputRefs = [ref]
                outputRefs = [DatasetRef(outputDSType, dataId, run=self.output_run)]
                dataset_refs[(outputDSType, dataId)] = outputRefs[0]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        dataset_types.add(self.packagesDSType)
        globalInitOutputs = [
            DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe), run=self.output_run)
        ]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=dataset_types,
        )
        self.universe = universe
        self.num_dataset_types = len(dataset_types)

    def testTaskGraph(self) -> None:
        """Every fixture task is present in the task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self) -> None:
        """Every fixture quantum is present in the quantum graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self) -> None:
        """Nodes can be looked up by ID; an unknown ID raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self) -> None:
        """The graph survives a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self) -> None:
        """Input quanta are those of the tasks with no upstream task (R, U)."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self) -> None:
        """Output quanta are those of the tasks with no downstream task (T, U)."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self) -> None:
        """The graph holds two quanta (one per data ID) for each task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self) -> None:
        """getQuantaForTask returns exactly the fixture quanta of each task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self) -> None:
        """getNumberOfQuantaForTask agrees with the fixture quantum counts."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self) -> None:
        """getNodesForTask returns nodes wrapping each task's fixture quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self) -> None:
        """``Dummy1Output`` is consumed by task S."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self) -> None:
        """``Dummy1Output`` is produced by task R."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self) -> None:
        """Both the producer (R) and the consumer (S) of ``Dummy1Output`` match."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self) -> None:
        """A TaskDef can be found by its task class name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self) -> None:
        """A TaskDef can be found by its label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self) -> None:
        """Quanta using ``Dummy1Input`` are exactly task R's quanta."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self) -> None:
        """allDatasetTypes names every connection that is not an InitOutput."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self) -> None:
        """A single-node subset keeps the build ID and trims dataset types."""
        allNodes = list(self.qGraph)
        firstNode = allNodes[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(firstNode.quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)
        # Depending on which task was first the list can contain different
        # number of datasets. The first task can be either Dummy1 or Dummy4.
        num_types = {"R": 4, "U": 3}
        self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

    def testSubsetToConnected(self) -> None:
        """subsetToConnected splits the graph into its connected pieces."""
        # Not connected: the graph has four independent pieces, namely one
        # R -> S -> T chain per data ID plus one standalone U quantum per
        # data ID.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum of the full graph belongs to exactly one subgraph.
        for node in self.qGraph:
            matches = sum(cg.checkQuantumInGraph(node.quantum) for cg in connectedGraphs)
            self.assertEqual(matches, 1)

        for cg in connectedGraphs:
            tskSet = set(cg.taskGraph)
            if len(tskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif len(tskSet) == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
            else:
                self.fail(f"Unexpected task set in connected subgraph: {tskSet}")

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            # These checks rely on iteration following the R -> S -> T
            # dependency order within a chain.
            allNodes = list(cg)
            self.assertEqual({allNodes[0]}, cg.determineInputsToQuantumNode(allNodes[1]))
            self.assertEqual({allNodes[1]}, cg.determineInputsToQuantumNode(allNodes[2]))

    def testDetermineOutputsOfQuantumNode(self) -> None:
        """Task R's nodes feed exactly task S's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self) -> None:
        """Task S's nodes connect to R's and T's nodes (and include themselves)."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self) -> None:
        """Ancestors of task S's nodes are R's nodes plus the nodes themselves."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self) -> None:
        """The fixture graph has no cycles."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self) -> None:
        """Round-trip through save/load, including partial loads and universes."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None  # narrow the type for mypy
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None  # narrow the type for mypy
                    self.assertGreater(len(refs), 0)

            # A universe differing only in version loads, with an INFO log.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # A universe with a different namespace is rejected.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self) -> None:
        """Round-trip through saveUri/loadUri, including partial loads."""
        # Only reserve a unique file name; the handle is closed right away so
        # that saveUri writes the file independently of it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri)
            self.assertEqual(restore.metadata, METADATA)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct nodes at random.
            nodeNumber, nodeNumber2 = random.sample([n.nodeId for n in self.qGraph], 2)

            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

            # Verify that loading more than one node works.
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )

            # Requesting a node ID that is not in the graph is an error.
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(uuid.uuid4(),))

            # A graphID that does not match the saved graph is an error.
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID=BuildId("NOTRIGHT"))
        finally:
            os.remove(uri)

        # A URI with an unsupported extension is rejected with TypeError.
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testSaveLoadNoRegistryDatasetTypes(self) -> None:
        """Test for reading quantum that is missing registry dataset types.

        This test depends on internals of QuantumGraph implementation, in
        particular that empty list of registry dataset types is not stored,
        which makes save file identical to the "old" format.
        """
        # Reset the list, this is safe as QuantumGraph itself does not use it.
        self.qGraph._registryDatasetTypes = []
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            self.assertEqual(restore.registryDatasetTypes(), [])

    def testContains(self) -> None:
        """A node obtained by iterating the graph is ``in`` the graph."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self) -> None:
        """The save header records the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

    def testUpdateRun(self) -> None:
        """Test for QuantumGraph.updateRun method."""
        self.assertEqual(check_output_run(self.qGraph, self.output_run), [])
        output_refs = get_output_refs(self.qGraph)
        self.assertGreater(len(output_refs), 0)
        graph_id = self.qGraph.graphID

        self.qGraph.updateRun("updated-run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        output_refs2 = get_output_refs(self.qGraph)
        self.assertEqual(len(output_refs2), len(output_refs))
        # All output dataset IDs must be updated.
        self.assertTrue({ref.id for ref in output_refs}.isdisjoint({ref.id for ref in output_refs2}))

        # Also update metadata.
        self.qGraph.updateRun("updated-run2", metadata_key="output_run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run2"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        assert self.qGraph.metadata is not None  # narrow the type for mypy
        self.assertIn("output_run", self.qGraph.metadata)
        self.assertEqual(self.qGraph.metadata["output_run"], "updated-run2")

        # Update graph ID.
        self.qGraph.updateRun("updated-run3", metadata_key="output_run", update_graph_id=True)
        self.assertEqual(check_output_run(self.qGraph, "updated-run3"), [])
        self.assertNotEqual(self.qGraph.graphID, graph_id)

588 

589 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Leak checks (e.g. open files) provided by `lsst.utils.tests.MemoryTestCase`."""

592 

593 

def setup_module(module) -> None:
    """Module-level setup hook for pytest.

    Parameters
    ----------
    module
        The module being set up; not used here.
    """
    # Same initialization that the ``__main__`` entry point performs.
    lsst.utils.tests.init()

597 

598 

if __name__ == "__main__":
    # Direct execution: initialize lsst.utils test support, then hand off to
    # the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()