Coverage for tests/test_quantumGraph.py: 19%

342 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-21 10:57 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28import os 

29import pickle 

30import random 

31import tempfile 

32import unittest 

33import uuid 

34from collections.abc import Iterable 

35from itertools import chain 

36 

37import lsst.pipe.base.connectionTypes as cT 

38import lsst.utils.tests 

39from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

40from lsst.pex.config import Field 

41from lsst.pipe.base import ( 

42 DatasetTypeName, 

43 PipelineTask, 

44 PipelineTaskConfig, 

45 PipelineTaskConnections, 

46 QuantumGraph, 

47 TaskDef, 

48) 

49from lsst.pipe.base.graph.quantumNode import BuildId, QuantumNode 

50from lsst.pipe.base.tests.util import check_output_run, get_output_refs 

51from lsst.utils.introspection import get_full_type_name 

52 

# Metadata attached to the test QuantumGraph; round-tripped in save/load tests.
METADATA = {"a": list(range(1, 4))}

54 

55 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Dummy connections class #1.

    Head of the Dummy1 -> Dummy2 -> Dummy3 chain: reads ``Dummy1Input`` and
    writes ``Dummy1Output`` per (A, B) data ID, and declares the
    ``Dummy1InitOutput`` init-output consumed by Dummy2.
    """

    initOutput = cT.InitOutput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    input = cT.Input(doc="n/a", name="Dummy1Input", storageClass="ExposureF", dimensions=("A", "B"))
    output = cT.Output(doc="n/a", name="Dummy1Output", storageClass="ExposureF", dimensions=("A", "B"))

62 

63 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Dummy config #1, wired to `Dummy1Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

68 

69 

class Dummy1PipelineTask(PipelineTask):
    """Dummy pipeline task #1 (configured by `Dummy1Config`)."""

    ConfigClass = Dummy1Config

74 

75 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Dummy connections class #2.

    Middle of the Dummy1 -> Dummy2 -> Dummy3 chain: consumes Dummy1's regular
    output and init-output, and produces ``Dummy2Output`` and
    ``Dummy2InitOutput`` for Dummy3.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    input = cT.Input(doc="n/a", name="Dummy1Output", storageClass="ExposureF", dimensions=("A", "B"))
    output = cT.Output(doc="n/a", name="Dummy2Output", storageClass="ExposureF", dimensions=("A", "B"))

83 

84 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Dummy config #2, wired to `Dummy2Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

89 

90 

class Dummy2PipelineTask(PipelineTask):
    """Dummy pipeline task #2 (configured by `Dummy2Config`)."""

    ConfigClass = Dummy2Config

95 

96 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Dummy connections class #3.

    Tail of the Dummy1 -> Dummy2 -> Dummy3 chain: consumes Dummy2's regular
    output and init-output, and produces ``Dummy3Output`` and
    ``Dummy3InitOutput``, which nothing else in this file consumes.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy3InitOutput", storageClass="ExposureF")
    input = cT.Input(doc="n/a", name="Dummy2Output", storageClass="ExposureF", dimensions=("A", "B"))
    output = cT.Output(doc="n/a", name="Dummy3Output", storageClass="ExposureF", dimensions=("A", "B"))

104 

105 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Dummy config #3, wired to `Dummy3Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

110 

111 

class Dummy3PipelineTask(PipelineTask):
    """Dummy pipeline task #3 (configured by `Dummy3Config`)."""

    ConfigClass = Dummy3Config

116 

117 

# A task whose datasets are not shared with any other task, used to check
# that an isolated task works fine in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Dummy connections class #4.

    Reads ``Dummy4Input`` and writes ``Dummy4Output``; no init-level datasets.
    """

    input = cT.Input(doc="n/a", name="Dummy4Input", storageClass="ExposureF", dimensions=("A", "B"))
    output = cT.Output(doc="n/a", name="Dummy4Output", storageClass="ExposureF", dimensions=("A", "B"))

125 

126 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Dummy config #4, wired to `Dummy4Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

131 

132 

class Dummy4PipelineTask(PipelineTask):
    """Dummy pipeline task #4 (configured by `Dummy4Config`)."""

    ConfigClass = Dummy4Config

137 

138 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    The fixture built in `setUp` has four tasks labelled ``R``, ``S``, ``T``
    and ``U`` (Dummy1..Dummy4). R -> S -> T are chained through their regular
    and init-level datasets; U shares no datasets with the others. Every task
    gets one quantum for each of the data IDs ``{A: 1, B: 2}`` and
    ``{A: 3, B: 4}``.
    """

    input_collection = "inputs"
    output_run = "run"

    def setUp(self) -> None:
        """Build a two-dimension universe and a QuantumGraph over it."""
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Connections that define no ``dimensions`` attribute become
            # dimensionless dataset types.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of quanta for that task.
        quantumMap: dict[TaskDef, set[Quantum]] = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        dataset_types = set()
        # Refs produced by upstream tasks, remembered so that downstream tasks
        # consume the very same DatasetRef that was produced.
        init_dataset_refs: dict[DatasetType, DatasetRef] = {}
        dataset_refs: dict[tuple[DatasetType, DataCoordinate], DatasetRef] = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Init-inputs: reuse the upstream init-output ref when one exists,
            # otherwise treat the dataset as pre-existing input.
            quantumInitInputs = None
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                if (initInputRef := init_dataset_refs.get(initInputDSType)) is None:
                    initInputRef = DatasetRef(
                        initInputDSType,
                        DataCoordinate.make_empty(universe),
                        run=self.input_collection,
                    )
                quantumInitInputs = [initInputRef]
                initInputs[taskDef] = quantumInitInputs
                dataset_types.add(initInputDSType)

            # Init-outputs are recorded at graph level only. They are kept in
            # a separate variable so they never leak into this task's quanta
            # as init-inputs.
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputRef = DatasetRef(
                    initOutputDSType, DataCoordinate.make_empty(universe), run=self.output_run
                )
                init_dataset_refs[initOutputDSType] = initOutputRef
                initOutputs[taskDef] = [initOutputRef]
                dataset_types.add(initOutputDSType)

            inputDSType = _makeDatasetType(connections.input)
            dataset_types.add(inputDSType)
            outputDSType = _makeDatasetType(connections.output)
            dataset_types.add(outputDSType)
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                # Consume the upstream output for this data ID if there is
                # one; otherwise the input is a pre-existing dataset.
                if (inputRef := dataset_refs.get((inputDSType, dataId))) is None:
                    inputRef = DatasetRef(inputDSType, dataId, run=self.input_collection)
                outputRef = DatasetRef(outputDSType, dataId, run=self.output_run)
                dataset_refs[(outputDSType, dataId)] = outputRef
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=quantumInitInputs,
                        inputs={inputDSType: [inputRef]},
                        outputs={outputDSType: [outputRef]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        dataset_types.add(self.packagesDSType)
        globalInitOutputs = [
            DatasetRef(self.packagesDSType, DataCoordinate.make_empty(universe), run=self.output_run)
        ]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=dataset_types,
        )
        self.universe = universe
        self.num_dataset_types = len(dataset_types)

    def testTaskGraph(self) -> None:
        """Every TaskDef used to build the graph appears in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self) -> None:
        """Every input quantum is wrapped by some node of the graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self) -> None:
        """Nodes can be looked up by ID; an unknown ID raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self) -> None:
        """A pickle round trip yields an equal graph."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self) -> None:
        """Input quanta are those of the chain head (R) and isolated U."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self) -> None:
        """Output quanta are those of the chain tail (T) and isolated U."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self) -> None:
        """Each task contributes two quanta (one per data ID)."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self) -> None:
        """getQuantaForTask returns exactly the quanta supplied per task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self) -> None:
        """getNumberOfQuantaForTask matches the supplied quantum counts."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self) -> None:
        """Nodes returned for a task wrap exactly that task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self) -> None:
        """Dummy1Output is consumed by task S."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self) -> None:
        """Dummy1Output is produced by task R."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self) -> None:
        """Dummy1Output is referenced by both its producer and consumer."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self) -> None:
        """Task definitions can be found by task class name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self) -> None:
        """Task definitions can be found by label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self) -> None:
        """Quanta touching Dummy1Input are exactly R's quanta."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self) -> None:
        """allDatasetTypes lists every non-InitOutput connection name."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections)
            for connection in conClass.allConnections.values()  # type: ignore
            if not isinstance(connection, cT.InitOutput)
        }
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self) -> None:
        """A single-node subset keeps the build ID and trims dataset types."""
        firstNode = next(iter(self.qGraph))
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(firstNode.quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)
        # Depending on which task was first the list can contain different
        # number of datasets. The first task can be either Dummy1 or Dummy4.
        num_types = {"R": 4, "U": 3}
        self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

    def testSubsetToConnected(self) -> None:
        """Splitting into connected components yields four subgraphs."""
        # Not connected: each data ID has its own R -> S -> T chain, and the
        # two U quanta are isolated, giving four components in total.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Task U (tasks[3]) is expected to be on its own.
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Component membership counts must add up to the full graph size.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            nodes = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, nodes)

    def testDetermineOutputsOfQuantumNode(self) -> None:
        """Downstream neighbours of R's nodes are exactly S's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self) -> None:
        """Connections of S's nodes cover R, S and T nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self) -> None:
        """Ancestors of S's nodes are R's nodes plus the S nodes themselves."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self) -> None:
        """The fixture graph is acyclic."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self) -> None:
        """Round trip through a file object, including partial loads and
        universe compatibility checks.
        """
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)

            # Different universe version: loads, but logs a message.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # Different namespace: incompatible, loading must fail.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self) -> None:
        """Round trip through a URI, including node and graph-ID selection."""
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct random nodes (the graph has more than one).
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non-existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(uuid.uuid4(),))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID=BuildId("NOTRIGHT"))
        finally:
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testSaveLoadNoRegistryDatasetTypes(self) -> None:
        """Test for reading quantum that is missing registry dataset types.

        This test depends on internals of QuantumGraph implementation, in
        particular that empty list of registry dataset types is not stored,
        which makes save file identical to the "old" format.
        """
        # Reset the list, this is safe as QuantumGraph itself does not use it.
        self.qGraph._registryDatasetTypes = []
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            self.assertEqual(restore.registryDatasetTypes(), [])

    def testContains(self) -> None:
        """Nodes obtained by iterating the graph are members of it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self) -> None:
        """The save header embeds the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

    def testUpdateRun(self) -> None:
        """Test for QuantumGraph.updateRun method."""
        self.assertEqual(check_output_run(self.qGraph, self.output_run), [])
        output_refs = get_output_refs(self.qGraph)
        self.assertGreater(len(output_refs), 0)
        graph_id = self.qGraph.graphID

        self.qGraph.updateRun("updated-run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        output_refs2 = get_output_refs(self.qGraph)
        self.assertEqual(len(output_refs2), len(output_refs))
        # All output dataset IDs must be updated.
        self.assertTrue({ref.id for ref in output_refs}.isdisjoint({ref.id for ref in output_refs2}))

        # Also update metadata.
        self.qGraph.updateRun("updated-run2", metadata_key="ouput_run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run2"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        assert self.qGraph.metadata is not None
        self.assertIn("ouput_run", self.qGraph.metadata)
        self.assertEqual(self.qGraph.metadata["ouput_run"], "updated-run2")

        # Update graph ID.
        self.qGraph.updateRun("updated-run3", metadata_key="ouput_run", update_graph_id=True)
        self.assertEqual(check_output_run(self.qGraph, "updated-run3"), [])
        self.assertNotEqual(self.qGraph.graphID, graph_id)

594 

595 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Check for leaked file descriptors using lsst.utils memory tests."""

598 

599 

def setup_module(module) -> None:
    """Pytest module-level hook: initialize lsst.utils test support."""
    lsst.utils.tests.init()

603 

604 

if __name__ == "__main__":
    # Same initialization pytest performs via setup_module, then run unittest.
    lsst.utils.tests.init()
    unittest.main()