Coverage for tests/test_quantumGraph.py: 19%

346 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-27 02:40 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28import os 

29import pickle 

30import random 

31import tempfile 

32import unittest 

33import uuid 

34from collections.abc import Iterable 

35from itertools import chain 

36 

37import lsst.pipe.base.connectionTypes as cT 

38import lsst.utils.tests 

39from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

40from lsst.pex.config import Field 

41from lsst.pipe.base import ( 

42 DatasetTypeName, 

43 PipelineTask, 

44 PipelineTaskConfig, 

45 PipelineTaskConnections, 

46 QuantumGraph, 

47 TaskDef, 

48) 

49from lsst.pipe.base.graph.quantumNode import BuildId, QuantumNode 

50from lsst.pipe.base.tests.util import check_output_run, get_output_refs 

51from lsst.utils.introspection import get_full_type_name 

52from lsst.utils.packages import Packages 

53 

# Arbitrary metadata handed to the QuantumGraph built by the test fixture.
METADATA = dict(a=[1, 2, 3])

55 

56 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for dummy task #1, the head of the dependent chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` per (A, B) data ID, and
    declares the ``Dummy1InitOutput`` init-output that `Dummy2Connections`
    consumes as its init-input.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )

63 

64 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

69 

70 

class Dummy1PipelineTask(PipelineTask):
    """First dummy pipeline task; only its class structure is used.

    It defines no processing of its own and exists so the tests can build
    `TaskDef` and `Quantum` objects that reference it.
    """

    ConfigClass = Dummy1Config

75 

76 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for dummy task #2, the middle of the dependent chain.

    Consumes task #1's ``Dummy1InitOutput`` and ``Dummy1Output``; produces
    ``Dummy2InitOutput`` and ``Dummy2Output`` for task #3.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )

84 

85 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

90 

91 

class Dummy2PipelineTask(PipelineTask):
    """Dummy pipeline task #2.

    Configured by `Dummy2Config`; sits between task #1 and task #3 in the
    dependent chain used by the quantum graph tests.
    """

    ConfigClass = Dummy2Config

96 

97 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for dummy task #3, the tail of the dependent chain.

    Consumes task #2's ``Dummy2InitOutput`` and ``Dummy2Output``; produces
    ``Dummy3InitOutput`` and ``Dummy3Output``, which nothing else reads.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )

105 

106 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

111 

112 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy pipeline task; only its class structure is used.

    Configured by `Dummy3Config`; it defines no processing of its own.
    """

    ConfigClass = Dummy3Config

117 

118 

# Dummy task #4 shares no datasets with tasks #1-#3; it checks that an
# isolated task is handled correctly in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for dummy task #4, which is independent of the others.

    Reads ``Dummy4Input`` and writes ``Dummy4Output``; it declares no init
    connections.
    """

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )

126 

127 

class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for `Dummy4PipelineTask`, bound to `Dummy4Connections`."""

    conf1 = Field[int](doc="dummy config", default=1)

132 

133 

class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy pipeline task, disconnected from the other three.

    Configured by `Dummy4Config`; it defines no processing of its own.
    """

    ConfigClass = Dummy4Config

138 

139 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    The fixture built in `setUp` holds four tasks labelled R, S, T and U.
    R -> S -> T form a chain (each consumes the previous task's output and
    init-output), while U shares no datasets with the others. Every task has
    one quantum for each of the data IDs ``{A: 1, B: 2}`` and
    ``{A: 3, B: 4}``.
    """

    input_collection = "inputs"
    output_run = "run"

    def setUp(self) -> None:
        """Build a two-dimension universe and a QuantumGraph over it."""
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            """Return a `DatasetType` matching a connection declaration.

            Connections without a ``dimensions`` attribute (the init
            connections here) get empty dimensions.
            """
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        dataset_types = set()
        # Init-output refs already produced by an upstream task, so that a
        # downstream init-input reuses the very same ref.
        init_dataset_refs: dict[DatasetType, DatasetRef] = {}
        # Output refs already produced by an upstream task, keyed by
        # (dataset type, data ID), so that downstream inputs reuse them.
        dataset_refs: dict[tuple[DatasetType, DataCoordinate], DatasetRef] = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            # Init-input and init-output refs are kept apart: only the
            # init-inputs belong in each Quantum's ``initInputs``.
            initInputRefs: list[DatasetRef] | None = None
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                if (ref := init_dataset_refs.get(initInputDSType)) is not None:
                    initInputRefs = [ref]
                else:
                    initInputRefs = [
                        DatasetRef(
                            initInputDSType,
                            DataCoordinate.make_empty(universe),
                            run=self.input_collection,
                        )
                    ]
                initInputs[taskDef] = initInputRefs
                dataset_types.add(initInputDSType)
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputRefs = [
                    DatasetRef(initOutputDSType, DataCoordinate.make_empty(universe), run=self.output_run)
                ]
                init_dataset_refs[initOutputDSType] = initOutputRefs[0]
                initOutputs[taskDef] = initOutputRefs
                dataset_types.add(initOutputDSType)
            inputDSType = _makeDatasetType(connections.input)
            dataset_types.add(inputDSType)
            outputDSType = _makeDatasetType(connections.output)
            dataset_types.add(outputDSType)
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                # Link to the upstream task's output when there is one;
                # otherwise the input comes from the input collection.
                if (ref := dataset_refs.get((inputDSType, dataId))) is None:
                    inputRefs = [DatasetRef(inputDSType, dataId, run=self.input_collection)]
                else:
                    inputRefs = [ref]
                outputRefs = [DatasetRef(outputDSType, dataId, run=self.output_run)]
                dataset_refs[(outputDSType, dataId)] = outputRefs[0]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        dataset_types.add(self.packagesDSType)
        globalInitOutputs = [
            DatasetRef(self.packagesDSType, DataCoordinate.make_empty(universe), run=self.output_run)
        ]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=dataset_types,
        )
        self.universe = universe
        self.num_dataset_types = len(dataset_types)

    def testTaskGraph(self) -> None:
        """Every task of the input mapping appears in the task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self) -> None:
        """Every input quantum appears as a node of the graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self) -> None:
        """Nodes can be looked up by ID; unknown IDs raise KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self) -> None:
        """The graph survives a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self) -> None:
        """Quanta of tasks with no upstream producer (R and U) are inputs."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self) -> None:
        """Quanta of tasks with no downstream consumer (T and U) are outputs."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self) -> None:
        """Each task contributes one quantum per data ID (two in total)."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self) -> None:
        """Per-task quanta match the mapping the graph was built from."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self) -> None:
        """Per-task quantum counts match the input mapping."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self) -> None:
        """Per-task nodes wrap exactly the quanta of that task."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self) -> None:
        """Dummy1Output is consumed by task S."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self) -> None:
        """Dummy1Output is produced by task R."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self) -> None:
        """Both the producer and the consumer of Dummy1Output are found."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self) -> None:
        """A task can be found by its class name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self) -> None:
        """A task can be found by its label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self) -> None:
        """Quanta reading Dummy1Input are exactly those of task R."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self) -> None:
        """All connection names except InitOutput ones are reported."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self) -> None:
        """A single-node subset keeps build ID and trims dataset types."""
        allNodes = list(self.qGraph)
        firstNode = allNodes[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(firstNode.quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)
        # Depending on which task was first the list can contain different
        # number of datasets. The first task can be either Dummy1 or Dummy4.
        num_types = {"R": 4, "U": 3}
        self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

    def testSubsetToConnected(self) -> None:
        """Splitting into connected components yields four subgraphs."""
        # False because there are four connected components: one R -> S -> T
        # chain per data ID, plus one isolated U quantum per data ID.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Each quantum belongs to exactly one component, so the number of
        # (quantum, component) memberships equals the number of quanta.
        count = sum(
            cg.checkQuantumInGraph(node.quantum) for node in self.qGraph for cg in connectedGraphs
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            # In a linear chain, allNodes[0] feeds allNodes[1]; check the
            # edge from both ends.
            nodes = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, nodes)
            nodes = set(cg.determineOutputsOfQuantumNode(allNodes[0]))
            self.assertEqual({allNodes[1]}, nodes)

    def testDetermineOutputsOfQuantumNode(self) -> None:
        """Quanta of R feed exactly the quanta of S."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self) -> None:
        """Connections of S quanta span R, S and T quanta."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self) -> None:
        """Ancestors of S quanta are the S quanta themselves plus R quanta."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self) -> None:
        """The fixture graph is acyclic."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self) -> None:
        """Round-trip the graph through a file object, fully and partially."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            # A partial load keeps the full list of registry dataset types.
            self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)

            # Different universes: only the version differs, so loading
            # succeeds and the mismatch is merely logged.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # A different namespace makes the universes incompatible.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self) -> None:
        """Round-trip the graph through a URI, fully and for chosen nodes."""
        # Only reserve a unique file name here; the handle is closed before
        # the graph is written to that path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri)
            self.assertEqual(restore.metadata, self.qGraph.metadata)
            self.assertEqual(self.qGraph, restore)
            # Pick two distinct nodes at random.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )
            # verify an error when requesting a non-existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(uuid.uuid4(),))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID=BuildId("NOTRIGHT"))
        finally:
            os.remove(uri)

        # An unrecognized file extension is expected to raise TypeError.
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testSaveLoadNoRegistryDatasetTypes(self) -> None:
        """Test for reading quantum that is missing registry dataset types.

        This test depends on internals of QuantumGraph implementation, in
        particular that empty list of registry dataset types is not stored,
        which makes save file identical to the "old" format.
        """
        # Reset the list, this is safe as QuantumGraph itself does not use it.
        self.qGraph._registryDatasetTypes = []
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            self.assertEqual(restore.registryDatasetTypes(), [])

    def testContains(self) -> None:
        """A node taken from the graph is reported as contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self) -> None:
        """The save header records the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

    def testUpdateRun(self) -> None:
        """Test for QuantumGraph.updateRun method."""
        self.assertEqual(check_output_run(self.qGraph, self.output_run), [])
        output_refs = get_output_refs(self.qGraph)
        self.assertGreater(len(output_refs), 0)
        graph_id = self.qGraph.graphID

        self.qGraph.updateRun("updated-run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        output_refs2 = get_output_refs(self.qGraph)
        self.assertEqual(len(output_refs2), len(output_refs))
        # All output dataset IDs must be updated.
        self.assertTrue({ref.id for ref in output_refs}.isdisjoint({ref.id for ref in output_refs2}))

        # Also update metadata.
        self.qGraph.updateRun("updated-run2", metadata_key="output_run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run2"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        assert self.qGraph.metadata is not None
        self.assertIn("output_run", self.qGraph.metadata)
        self.assertEqual(self.qGraph.metadata["output_run"], "updated-run2")

        # Update graph ID.
        self.qGraph.updateRun("updated-run3", metadata_key="output_run", update_graph_id=True)
        self.assertEqual(check_output_run(self.qGraph, "updated-run3"), [])
        self.assertNotEqual(self.qGraph.graphID, graph_id)

    def testMetadataPackage(self) -> None:
        """Test package versions added to QuantumGraph metadata."""
        packages = Packages.fromSystem()
        self.assertEqual(self.qGraph.metadata["packages"], packages)

600 

601 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Check for file leaks, via the checks inherited from MemoryTestCase."""

604 

605 

def setup_module(module) -> None:
    """Initialize ``lsst.utils.tests`` before this module's tests run.

    pytest invokes this module-level hook with the module object, which is
    not needed here.
    """
    lsst.utils.tests.init()

609 

610 

if __name__ == "__main__":
    # Script execution: set up the LSST test helpers, then let unittest
    # discover and run the test cases in this module.
    lsst.utils.tests.init()
    unittest.main()