Coverage for tests/test_quantumGraph.py: 21%
293 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-23 03:02 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-23 03:02 -0800
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import QuantumNode
44from lsst.utils.introspection import get_full_type_name
# Metadata attached to the graph built in ``QuantumGraphTestCase.setUp``; the
# URI save/load test compares the restored graph's metadata against it.
METADATA = dict(a=[1, 2, 3])
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Reads ``Dummy1Input`` and writes ``Dummy1Output``; also declares the
    init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration of the first dummy task; holds one integer field."""

    conf1 = Field(
        doc="dummy config",
        dtype=int,
        default=1,
    )
class Dummy1PipelineTask(PipelineTask):
    """First dummy task, configured by `Dummy1Config`."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Reads ``Dummy1Output`` (the output of the first dummy task) and writes
    ``Dummy2Output``; its init-input is ``Dummy1InitOutput`` and its
    init-output is ``Dummy2InitOutput``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration of the second dummy task; holds one integer field."""

    conf1 = Field(
        doc="dummy config",
        dtype=int,
        default=1,
    )
class Dummy2PipelineTask(PipelineTask):
    """Second dummy task, configured by `Dummy2Config`."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Reads ``Dummy2Output`` (the output of the second dummy task) and writes
    ``Dummy3Output``; its init-input is ``Dummy2InitOutput`` and its
    init-output is ``Dummy3InitOutput``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration of the third dummy task; holds one integer field."""

    conf1 = Field(
        doc="dummy config",
        dtype=int,
        default=1,
    )
class Dummy3PipelineTask(PipelineTask):
    """Third dummy task, configured by `Dummy3Config`."""

    ConfigClass = Dummy3Config
# Test if a Task that does not interact with the other Tasks works fine in
# the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the fourth dummy task.

    Reads ``Dummy4Input`` and writes ``Dummy4Output``; neither dataset type
    name is used by any of the other dummy connection classes.
    """

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration of the fourth dummy task; holds one integer field."""

    conf1 = Field(
        doc="dummy config",
        dtype=int,
        default=1,
    )
class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy task, configured by `Dummy4Config`."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from the four dummy tasks (labels "R", "S", "T"
    and "U") with one quantum per task for each data ID in ``_dataIds``.
    Tasks "R", "S" and "T" are chained through their input/output dataset
    type names; task "U" shares no dataset type with the others.
    """

    # (A, B) dimension values for which every task gets one quantum.
    _dataIds = ((1, 2), (3, 4))

    def setUp(self):
        """Build the dimension universe, the quanta and the graph under test.

        Sets ``self.config``, ``self.universe``, ``self.tasks`` (task
        definitions in R, S, T, U order), ``self.quantumMap`` (task definition
        to set of quanta), ``self.packagesDSType`` and ``self.qGraph``.
        """
        tableStorage = "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage"
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                # Dimensions "A" and "B" are configured identically.
                "elements": {
                    name: {
                        "keys": [{"name": "id", "type": "int"}],
                        "storage": {"cls": tableStorage},
                    }
                    for name in ("A", "B")
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Init connections carry no ``dimensions`` attribute, hence the
            # empty default.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        def _makeInitRefs(connection):
            # Init datasets are identified by the empty data ID.
            return [DatasetRef(_makeDatasetType(connection), DataCoordinate.makeEmpty(universe))]

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Keep init-input and init-output refs in separate variables so
            # that each Quantum is handed its init-INPUT refs only (or None
            # when the task declares no init-inputs).
            initInputRefs = None
            if connections.initInputs:
                initInputRefs = _makeInitRefs(connections.initInput)
                initInputs[taskDef] = initInputRefs
            if connections.initOutputs:
                initOutputs[taskDef] = _makeInitRefs(connections.initOutput)

            inputDSType = _makeDatasetType(connections.input)
            outputDSType = _makeDatasetType(connections.output)
            quantumSet = set()
            for a, b in self._dataIds:
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet

        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        globalInitOutputs = [DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe))]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
        )
        self.universe = universe

    def testTaskGraph(self):
        """Every task definition used to build the graph is in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum used to build the graph appears on one of its nodes."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Lookup by node id returns the node; an unknown id raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip yields a graph equal to the original."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """The input quanta are those of the first and the standalone task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self):
        """The output quanta are those of the third and the standalone task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph holds one node per task per data ID."""
        self.assertEqual(len(self.qGraph), len(self._dataIds) * len(self.tasks))

    def testGetQuantaForTask(self):
        """Quanta reported per task match those supplied at construction."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self):
        """The per-task quantum count matches the construction input."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self):
        """Nodes reported per task wrap exactly that task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """The first task found consuming ``Dummy1Output`` is the second task."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """The task producing ``Dummy1Output`` is the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is associated with the first two tasks."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """The first match for the first task's qualified name is that task."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Label "R" resolves to the first task definition."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """Quanta associated with ``Dummy1Input`` are those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """The graph's dataset types are all connections except init-outputs."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A single-node subset keeps the quantum, build id and global refs."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)

    def testSubsetToConnected(self):
        """The graph splits into four connected sub-graphs covering all nodes."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Summed over all sub-graphs, the membership hits must equal the
        # number of nodes in the full graph.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            inputNodes = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, inputNodes)

    def testDetermineOutputsOfQuantumNode(self):
        """Outputs of the first task's nodes are the second task's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        """Connections of the second task's nodes span the first three tasks."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = set(self.qGraph.getNodesForTask(self.tasks[0]))
        matchNodes |= set(self.qGraph.getNodesForTask(self.tasks[2]))
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        """Ancestors of the second task's nodes include the first task's nodes.

        The expected set also contains the queried nodes themselves.
        """
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = set(self.qGraph.getNodesForTask(self.tasks[0])) | set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        """The test graph reports no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round-trip the graph through a file object.

        Covers a full load, a single-node load, restoration of the per-task
        init-input/init-output refs, and loading with a universe that differs
        in version (accepted, with a log message) or namespace (rejected).
        """
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)

            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)

            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # Different universes.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        """Round-trip the graph through a URI.

        Covers a full load, loads restricted to one and to two nodes, the
        errors raised for an unknown node or a mismatched graph id, and the
        rejection of a URI with an unsupported extension.
        """
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)

                # Pick two distinct nodes at random.
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)

                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )

                # verify an error when requesting a non-existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # The file was created with delete=False, so remove it by hand.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testContains(self):
        """A node obtained by iterating the graph is contained in the graph."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        """The save header records the universe's dimension configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`."""
def setup_module(module):
    """Initialize the LSST test utilities for this module.

    Parameters
    ----------
    module
        The module being set up; not used here.
    """
    lsst.utils.tests.init()
# Script entry point: initialize the LSST test utilities, then run the tests.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()