Coverage for tests/test_quantumGraph.py: 19%
359 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-24 10:01 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-24 10:01 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28import os
29import pickle
30import random
31import tempfile
32import unittest
33import uuid
34from itertools import chain
36import lsst.pipe.base.connectionTypes as cT
37import lsst.utils.tests
38from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
39from lsst.pex.config import Field
40from lsst.pipe.base import (
41 DatasetTypeName,
42 PipelineTask,
43 PipelineTaskConfig,
44 PipelineTaskConnections,
45 QuantumGraph,
46 TaskDef,
47)
48from lsst.pipe.base.graph.quantumNode import BuildId
49from lsst.pipe.base.tests.util import check_output_run, get_output_refs
50from lsst.utils.introspection import get_full_type_name
51from lsst.utils.packages import Packages
# Metadata dictionary handed to the QuantumGraph built in the test case.
METADATA = dict(a=[1, 2, 3])
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of dummy task #1; its init-output and output are consumed
    by ``Dummy2Connections``.
    """

    initOutput = cT.InitOutput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration of dummy task #1, wired to ``Dummy1Connections``."""

    conf1 = Field[int](doc="dummy config", default=1)
class Dummy1PipelineTask(PipelineTask):
    """First dummy pipeline task; configured by ``Dummy1Config``."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of dummy task #2; it consumes task #1's init-output and
    output, and produces datasets read by ``Dummy3Connections``.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration of dummy task #2, wired to ``Dummy2Connections``."""

    conf1 = Field[int](doc="dummy config", default=1)
class Dummy2PipelineTask(PipelineTask):
    """Dummy pipeline task #2."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of dummy task #3; it consumes task #2's init-output and
    output.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy3InitOutput", storageClass="ExposureF")
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration of dummy task #3, wired to ``Dummy3Connections``."""

    conf1 = Field[int](doc="dummy config", default=1)
class Dummy3PipelineTask(PipelineTask):
    """Third dummy pipeline task; configured by ``Dummy3Config``."""

    ConfigClass = Dummy3Config
# Dummy task #4 shares no datasets with the other dummy tasks; it checks that
# an isolated task still fits into the graph correctly.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of dummy task #4, which has no init connections."""

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration of dummy task #4, wired to ``Dummy4Connections``."""

    conf1 = Field[int](doc="dummy config", default=1)
class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy pipeline task; configured by ``Dummy4Config``."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from four dummy tasks labelled ``R``, ``S``,
    ``T`` and ``U``. ``R``, ``S`` and ``T`` are chained (each consumes the
    previous task's output and init-output), while ``U`` shares no datasets
    with the others. Every task gets one quantum for each of the data IDs
    ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``.
    """

    input_collection = "inputs"
    output_run = "run"

    def setUp(self) -> None:
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            """Return a DatasetType in ``universe`` matching ``connection``."""
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of quanta generated for it.
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        dataset_types = set()
        # Refs already produced by an upstream task, so that downstream tasks
        # consume the very same refs instead of freshly-made ones.
        init_dataset_refs: dict[DatasetType, DatasetRef] = {}
        dataset_refs: dict[tuple[DatasetType, DataCoordinate], DatasetRef] = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            # Init-input refs are what each quantum receives as ``initInputs``.
            # They are kept in a variable separate from the init-output refs
            # so the task's own init-outputs never end up as quantum inputs.
            initInputRefs = None
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                if (ref := init_dataset_refs.get(initInputDSType)) is not None:
                    initInputRefs = [ref]
                else:
                    initInputRefs = [
                        DatasetRef(
                            initInputDSType,
                            DataCoordinate.make_empty(universe),
                            run=self.input_collection,
                        )
                    ]
                initInputs[taskDef] = initInputRefs
                dataset_types.add(initInputDSType)
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputRefs = [
                    DatasetRef(initOutputDSType, DataCoordinate.make_empty(universe), run=self.output_run)
                ]
                init_dataset_refs[initOutputDSType] = initOutputRefs[0]
                initOutputs[taskDef] = initOutputRefs
                dataset_types.add(initOutputDSType)
            inputDSType = _makeDatasetType(connections.input)
            dataset_types.add(inputDSType)
            outputDSType = _makeDatasetType(connections.output)
            dataset_types.add(outputDSType)
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                # Reuse the upstream task's output ref for this data ID if any.
                if (ref := dataset_refs.get((inputDSType, dataId))) is None:
                    inputRefs = [DatasetRef(inputDSType, dataId, run=self.input_collection)]
                else:
                    inputRefs = [ref]
                outputRefs = [DatasetRef(outputDSType, dataId, run=self.output_run)]
                dataset_refs[(outputDSType, dataId)] = outputRefs[0]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quanta_by_task_label = {task_def.label: set(quanta) for task_def, quanta in quantumMap.items()}
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        dataset_types.add(self.packagesDSType)
        globalInitOutputs = [
            DatasetRef(self.packagesDSType, DataCoordinate.make_empty(universe), run=self.output_run)
        ]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=dataset_types,
        )
        self.universe = universe
        self.num_dataset_types = len(dataset_types)

    def testTaskGraph(self) -> None:
        task_def_labels: list[str] = []
        for taskDef in self.quantumMap:
            task_def_labels.append(taskDef.label)
            self.assertIn(taskDef, self.qGraph.taskGraph)
        self.assertCountEqual(task_def_labels, self.qGraph.pipeline_graph.tasks.keys())

    def testGraph(self) -> None:
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self) -> None:
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self) -> None:
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self) -> None:
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self) -> None:
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self) -> None:
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self) -> None:
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self) -> None:
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self) -> None:
        for task in self.tasks:
            nodes = list(self.qGraph.getNodesForTask(task))
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])
            for node in nodes:
                self.assertEqual(
                    node.task_node.task_class_name,
                    self.qGraph.pipeline_graph.tasks[node.task_node.label].task_class_name,
                )

    def testFindTasksWithInput(self) -> None:
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self) -> None:
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self) -> None:
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self) -> None:
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self) -> None:
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self) -> None:
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self) -> None:
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self) -> None:
        allNodes = list(self.qGraph)
        firstNode = allNodes[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(firstNode.quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)
        # Depending on which task was first the list can contain different
        # number of datasets. The first task can be either Dummy1 or Dummy4.
        num_types = {"R": 4, "U": 3}
        self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

    def testSubsetToConnected(self) -> None:
        # False because there are two quantum chains for two distinct sets of
        # dimensions.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own.
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Count quantum memberships across all connected subgraphs; the total
        # must match the number of quanta in the full graph.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            nodes = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, nodes)

    def testDetermineOutputsOfQuantumNode(self) -> None:
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self) -> None:
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # Outputs contain the nodes tested for because it is a complete graph.
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    # NOTE(review): "Ansestors" is misspelled; the name is kept so existing
    # test selectors keep matching.
    def testDetermineAnsestorsOfQuantumNode(self) -> None:
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self) -> None:
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self) -> None:
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node.
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)

            # A different but compatible universe only logs a message.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # A different namespace makes the universe incompatible.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self) -> None:
        # Create the file and close our handle immediately so that
        # ``saveUri`` can reopen the path on any platform; the file is
        # removed in ``finally``.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri)
            self.assertEqual(restore.metadata, self.qGraph.metadata)
            self.assertEqual(self.qGraph, restore)
            # Pick two distinct nodes at random to load back individually.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
            # Verify that more than one node works.
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )
            # Verify an error when requesting a non-existent node number.
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(uuid.uuid4(),))

            # Verify a graphID that does not match will be an error.
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID=BuildId("NOTRIGHT"))
        finally:
            os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testSaveLoadNoRegistryDatasetTypes(self) -> None:
        """Test for reading quantum that is missing registry dataset types.

        This test depends on internals of QuantumGraph implementation, in
        particular that empty list of registry dataset types is not stored,
        which makes save file identical to the "old" format.
        """
        # Reset the list, this is safe as QuantumGraph itself does not use it.
        self.qGraph._registryDatasetTypes = []
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            self.assertEqual(restore.registryDatasetTypes(), [])

    def testContains(self) -> None:
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self) -> None:
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

    def testUpdateRun(self) -> None:
        """Test for QuantumGraph.updateRun method."""
        self.assertEqual(check_output_run(self.qGraph, self.output_run), [])
        output_refs = get_output_refs(self.qGraph)
        self.assertGreater(len(output_refs), 0)
        graph_id = self.qGraph.graphID

        self.qGraph.updateRun("updated-run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        output_refs2 = get_output_refs(self.qGraph)
        self.assertEqual(len(output_refs2), len(output_refs))
        # All output dataset IDs must be updated.
        self.assertTrue({ref.id for ref in output_refs}.isdisjoint({ref.id for ref in output_refs2}))

        # Also update metadata.
        self.qGraph.updateRun("updated-run2", metadata_key="output_run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run2"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        assert self.qGraph.metadata is not None
        self.assertIn("output_run", self.qGraph.metadata)
        self.assertEqual(self.qGraph.metadata["output_run"], "updated-run2")

        # Update graph ID.
        self.qGraph.updateRun("updated-run3", metadata_key="output_run", update_graph_id=True)
        self.assertEqual(check_output_run(self.qGraph, "updated-run3"), [])
        self.assertNotEqual(self.qGraph.graphID, graph_id)

    def testMetadataPackage(self) -> None:
        """Test package versions added to QuantumGraph metadata."""
        packages = Packages.fromSystem()
        self.assertFalse(self.qGraph.metadata["packages"].difference(packages))

    def test_get_task_quanta(self) -> None:
        for task_label in self.qGraph.pipeline_graph.tasks.keys():
            quanta = self.qGraph.get_task_quanta(task_label)
            self.assertCountEqual(quanta.values(), self.quanta_by_task_label[task_label])

    def testGetSummary(self) -> None:
        """Test for QuantumGraph.getSummary method."""
        summary = self.qGraph.getSummary()
        self.assertEqual(self.qGraph.graphID, summary.graphID)
        self.assertEqual(len(summary.qgraphTaskSummaries), len(self.qGraph.taskGraph))
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the file-leak checks inherited from ``lsst.utils.tests.MemoryTestCase``."""
def setup_module(module) -> None:
    """Configure pytest for this test module.

    The ``module`` argument is accepted but not used; the only action is
    ``lsst.utils.tests.init()``.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Initialize the LSST test utilities first, then hand control to
    # ``unittest.main`` to run the test cases in this module.
    lsst.utils.tests.init()
    unittest.main()