Coverage for tests/test_quantumGraph.py: 19%
338 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-11 02:00 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import BuildId, QuantumNode
44from lsst.pipe.base.tests.util import check_output_run
45from lsst.utils.introspection import get_full_type_name
# Arbitrary metadata attached to the test graph; tests check it survives
# a save/load round trip.
METADATA = dict(a=[1, 2, 3])
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Its ``Dummy1Output`` and ``Dummy1InitOutput`` datasets are consumed by
    ``Dummy2Connections``.
    """

    initOutput = cT.InitOutput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    input = cT.Input(doc="n/a", name="Dummy1Input", dimensions=("A", "B"), storageClass="ExposureF")
    output = cT.Output(doc="n/a", name="Dummy1Output", dimensions=("A", "B"), storageClass="ExposureF")
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for ``Dummy1PipelineTask``; carries one dummy field."""

    conf1 = Field[int](doc="dummy config", default=1)
class Dummy1PipelineTask(PipelineTask):
    """Placeholder task bound to ``Dummy1Config``; it is never run."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    It consumes the outputs of ``Dummy1Connections``, and its own outputs
    are consumed by ``Dummy3Connections``.
    """

    initInput = cT.InitInput(doc="n/a", name="Dummy1InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    input = cT.Input(doc="n/a", name="Dummy1Output", dimensions=("A", "B"), storageClass="ExposureF")
    output = cT.Output(doc="n/a", name="Dummy2Output", dimensions=("A", "B"), storageClass="ExposureF")
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for ``Dummy2PipelineTask``; carries one dummy field."""

    conf1 = Field[int](doc="dummy config", default=1)
class Dummy2PipelineTask(PipelineTask):
    """Placeholder task bound to ``Dummy2Config``; it is never run."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task; consumes Dummy2's outputs."""

    initInput = cT.InitInput(doc="n/a", name="Dummy2InitOutput", storageClass="ExposureF")
    initOutput = cT.InitOutput(doc="n/a", name="Dummy3InitOutput", storageClass="ExposureF")
    input = cT.Input(doc="n/a", name="Dummy2Output", dimensions=("A", "B"), storageClass="ExposureF")
    output = cT.Output(doc="n/a", name="Dummy3Output", dimensions=("A", "B"), storageClass="ExposureF")
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for ``Dummy3PipelineTask``; carries one dummy field."""

    conf1 = Field[int](doc="dummy config", default=1)
class Dummy3PipelineTask(PipelineTask):
    """Placeholder task bound to ``Dummy3Config``; it is never run."""

    ConfigClass = Dummy3Config
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of a dummy task that shares no datasets with the others.

    It exists to check that a task which does not interact with the other
    tasks still works fine in the graph.
    """

    input = cT.Input(doc="n/a", name="Dummy4Input", dimensions=("A", "B"), storageClass="ExposureF")
    output = cT.Output(doc="n/a", name="Dummy4Output", dimensions=("A", "B"), storageClass="ExposureF")
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for ``Dummy4PipelineTask``; carries one dummy field."""

    conf1 = Field[int](doc="dummy config", default=1)
class Dummy4PipelineTask(PipelineTask):
    """Placeholder task bound to ``Dummy4Config``; it is never run."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph."""

    input_collection = "inputs"
    output_run = "run"

    def setUp(self) -> None:
        """Build a two-dimension universe and a graph of four dummy tasks.

        Tasks labelled R, S and T are chained through their regular and init
        datasets; task U shares no datasets with them. Every task gets one
        quantum for each of the data IDs ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``.
        """
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Init connections have no ``dimensions`` attribute, hence the
            # empty-tuple fallback.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        dataset_types = set()
        init_dataset_refs: dict[DatasetType, DatasetRef] = {}
        dataset_refs: dict[tuple[DatasetType, DataCoordinate], DatasetRef] = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            # Init-input refs are handed to the quanta built below, while
            # init-output refs are only recorded graph-wide. They live in
            # separate variables so the init-output branch cannot overwrite
            # what the quanta receive as their initInputs.
            initInputRefs = None
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                # Reuse the ref created for an upstream task's init-output so
                # that producer and consumer share the same dataset.
                if (ref := init_dataset_refs.get(initInputDSType)) is not None:
                    initInputRefs = [ref]
                else:
                    initInputRefs = [
                        DatasetRef(
                            initInputDSType,
                            DataCoordinate.makeEmpty(universe),
                            run=self.input_collection,
                        )
                    ]
                initInputs[taskDef] = initInputRefs
                dataset_types.add(initInputDSType)
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputRefs = [
                    DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe), run=self.output_run)
                ]
                init_dataset_refs[initOutputDSType] = initOutputRefs[0]
                initOutputs[taskDef] = initOutputRefs
                dataset_types.add(initOutputDSType)
            inputDSType = _makeDatasetType(connections.input)
            dataset_types.add(inputDSType)
            outputDSType = _makeDatasetType(connections.output)
            dataset_types.add(outputDSType)
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                # Chain quanta by reusing an upstream output ref as this
                # task's input when one exists for the same data ID.
                if (ref := dataset_refs.get((inputDSType, dataId))) is None:
                    inputRefs = [DatasetRef(inputDSType, dataId, run=self.input_collection)]
                else:
                    inputRefs = [ref]
                outputRefs = [DatasetRef(outputDSType, dataId, run=self.output_run)]
                dataset_refs[(outputDSType, dataId)] = outputRefs[0]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        dataset_types.add(self.packagesDSType)
        globalInitOutputs = [
            DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe), run=self.output_run)
        ]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=dataset_types,
        )
        self.universe = universe
        self.num_dataset_types = len(dataset_types)

    def testTaskGraph(self) -> None:
        """Every TaskDef used to build the graph appears in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self) -> None:
        """Every input quantum is present in the node graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self) -> None:
        """Nodes can be looked up by ID; unknown IDs raise KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self) -> None:
        """The graph survives a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self) -> None:
        """Input quanta are those of the chain head (R) and the lone task (U)."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self) -> None:
        """Output quanta are those of the chain tail (T) and the lone task (U)."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self) -> None:
        """Each task contributes two quanta to the graph."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self) -> None:
        """Quanta are retrievable per task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self) -> None:
        """The per-task quantum count matches the input mapping."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self) -> None:
        """Nodes returned per task wrap exactly that task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self) -> None:
        """Dummy1Output is consumed by the second task."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self) -> None:
        """Dummy1Output is produced by the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self) -> None:
        """Dummy1Output touches exactly the first two tasks."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self) -> None:
        """TaskDefs can be found by task class name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self) -> None:
        """TaskDefs can be found by label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self) -> None:
        """Quanta using Dummy1Input are those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self) -> None:
        """allDatasetTypes lists every connection name except init-outputs."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self) -> None:
        """A single-node subset keeps the build ID and the relevant types."""
        allNodes = list(self.qGraph)
        firstNode = allNodes[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(firstNode.quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)
        # Depending on which task was first the list can contain different
        # number of datasets. The first task can be either Dummy1 or Dummy4.
        num_types = {"R": 4, "U": 3}
        self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

    def testSubsetToConnected(self) -> None:
        """The graph splits into two R->S->T chains and two lone U quanta."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every node of the full graph lands in exactly one component.
        count = sum(
            cg.checkQuantumInGraph(node.quantum) for node in self.qGraph for cg in connectedGraphs
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            # Assumes iteration yields the three chain nodes in dependency
            # order, which the existing allNodes[1] -> allNodes[0] check
            # already relies on.
            allNodes = list(cg)
            nodes = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, nodes)
            nodes = cg.determineInputsToQuantumNode(allNodes[2])
            self.assertEqual({allNodes[1]}, nodes)

    def testDetermineOutputsOfQuantumNode(self) -> None:
        """Nodes downstream of the first task are the second task's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self) -> None:
        """Connections of the middle task cover the whole chain."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self) -> None:
        """Ancestors of the middle task are itself plus the first task."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self) -> None:
        """The test graph is acyclic."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self) -> None:
        """Round-trip through a file object, in full and for one node."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    assert refs is not None
                    self.assertGreater(len(refs), 0)

            # Different universes.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self) -> None:
        """Round-trip through a URI, in full and for node subsets."""
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)
                nodeIds = [n.nodeId for n in self.qGraph]
                # Draw two distinct node positions in one call.
                nodeNumberId, nodeNumberId2 = random.sample(range(len(nodeIds)), 2)
                nodeNumber = nodeIds[nodeNumberId]
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                nodeNumber2 = nodeIds[nodeNumberId2]
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(uuid.uuid4(),))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID=BuildId("NOTRIGHT"))
        finally:
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testSaveLoadNoRegistryDatasetTypes(self) -> None:
        """Test for reading quantum that is missing registry dataset types.

        This test depends on internals of QuantumGraph implementation, in
        particular that empty list of registry dataset types is not stored,
        which makes save file identical to the "old" format.
        """
        # Reset the list, this is safe as QuantumGraph itself does not use it.
        self.qGraph._registryDatasetTypes = []
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            self.assertEqual(restore.registryDatasetTypes(), [])

    def testContains(self) -> None:
        """Nodes obtained by iteration are reported as contained."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self) -> None:
        """The save header embeds the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

    def testUpdateRun(self) -> None:
        """Test for QuantumGraph.updateRun method."""
        self.assertEqual(check_output_run(self.qGraph, self.output_run), [])
        graph_id = self.qGraph.graphID

        self.qGraph.updateRun("updated-run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)

        # Also update metadata.
        self.qGraph.updateRun("updated-run2", metadata_key="ouput_run")
        self.assertEqual(check_output_run(self.qGraph, "updated-run2"), [])
        self.assertEqual(self.qGraph.graphID, graph_id)
        assert self.qGraph.metadata is not None
        self.assertIn("ouput_run", self.qGraph.metadata)
        self.assertEqual(self.qGraph.metadata["ouput_run"], "updated-run2")

        # Update graph ID.
        self.qGraph.updateRun("updated-run3", metadata_key="ouput_run", update_graph_id=True)
        self.assertEqual(check_output_run(self.qGraph, "updated-run3"), [])
        self.assertNotEqual(self.qGraph.graphID, graph_id)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the leak checks inherited unchanged from lsst.utils.tests.MemoryTestCase."""
def setup_module(module) -> None:
    """Module-level setup hook (named per the pytest convention).

    Initializes the LSST test utilities once before this module's tests run;
    ``module`` is accepted for the hook signature but not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities first, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()