Coverage for tests/test_quantumGraph.py : 31%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from itertools import chain
23import pickle
24import tempfile
25import unittest
26from lsst.daf.butler import DimensionUniverse
28from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
29 DatasetTypeName, IncompatibleGraphError)
30import lsst.pipe.base.connectionTypes as cT
31from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
32from lsst.pex.config import Field
33from lsst.pipe.base.graph.quantumNode import NodeId, BuildId
34import lsst.utils.tests
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Declares one init-output, one input and one output dataset, all with
    storage class ``ExposureF``; the input and output use dimensions
    ``("A", "B")``.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration bound to `Dummy1Connections`, with a single integer
    field.
    """

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """First dummy task; it only declares its configuration class."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Its init-input and input are named after the init-output and output
    declared by `Dummy1Connections` (``Dummy1InitOutput`` and
    ``Dummy1Output``). All datasets use storage class ``ExposureF``.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration bound to `Dummy2Connections`, with a single integer
    field.
    """

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; it only declares its configuration class."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Its init-input and input are named after the init-output and output
    declared by `Dummy2Connections` (``Dummy2InitOutput`` and
    ``Dummy2Output``). All datasets use storage class ``ExposureF``.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration bound to `Dummy3Connections`, with a single integer
    field.
    """

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; it only declares its configuration class."""

    ConfigClass = Dummy3Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from the three dummy tasks (labels "R", "S" and
    "T", where each task's input is named after the previous task's output)
    with one quantum per task for each of the data IDs ``(A=1, B=2)`` and
    ``(A=3, B=4)``. Every test then queries that six-quantum graph.
    """

    def setUp(self):
        """Create the dimension universe, the quanta and the graph under test.

        Stores ``self.tasks`` (list of `TaskDef`), ``self.quantumMap``
        (`TaskDef` -> set of `Quantum`), ``self.qGraph`` and
        ``self.universe``.
        """
        dimensionConfig = Config({
            "dimensions": {
                "version": 1,
                "skypix": {},
                "elements": {
                    "A": {
                        "keys": [{
                            "name": "id",
                            "type": "int",
                        }],
                    },
                    "B": {
                        "keys": [{
                            "name": "id",
                            "type": "int",
                        }],
                    }
                }
            }
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            # Named differently from the dimension configuration above so the
            # two configurations are not confused.
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            for a, b in ((1, 2), (3, 4)):
                # The input ref, the output ref and the quantum itself all
                # share this data ID, so build it once.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if connections.initInputs:
                    initInputDSType = DatasetType(connections.initInput.name,
                                                  tuple(),
                                                  storageClass=connections.initInput.storageClass,
                                                  universe=universe)
                    initRefs = [DatasetRef(initInputDSType,
                                           DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                inputDSType = DatasetType(connections.input.name,
                                          connections.input.dimensions,
                                          storageClass=connections.input.storageClass,
                                          universe=universe,
                                          )
                inputRefs = [DatasetRef(inputDSType, dataId)]
                outputDSType = DatasetType(connections.output.name,
                                           connections.output.dimensions,
                                           storageClass=connections.output.storageClass,
                                           universe=universe,
                                           )
                outputRefs = [DatasetRef(outputDSType, dataId)]
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: inputRefs},
                            outputs={outputDSType: outputRefs}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _stripTaskNameQualification(self, restored):
        """Reduce every task name in both graphs to its last dotted component.

        This is a hack for the unit test since the qualified name will be
        different as it will be __main__ here, but qualified to the
        unittest module name when restored.

        Parameters
        ----------
        restored : `QuantumGraph`
            Graph obtained by round-tripping ``self.qGraph``; it is modified
            in place, as is ``self.qGraph``.
        """
        for saved, loaded in zip(self.qGraph._quanta.keys(),
                                 restored._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        """Every task definition used to build the graph is in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum used to build the graph is held by one of its nodes."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """A node is found by its own id; an id from another build raises."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A graph round-tripped through pickle compares equal to the original."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._stripTaskNameQualification(restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """The graph's input quanta are exactly those of the first task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        """The graph's output quanta are exactly those of the last task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        """The graph holds six quanta: three tasks times two data IDs."""
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        """Each task maps back to the quanta it was given in ``setUp``."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        """The first task reported as consuming ``Dummy1Output`` is the second task."""
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        """The task producing ``Dummy1Output`` is the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is associated with the first two tasks."""
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        """Looking up the first task's class name returns its task definition first."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Label "R" resolves to the first task definition."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """The quanta using ``Dummy1Input`` are those of the first task."""
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        """The graph reports the name of every connection of the three tasks."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """Subsetting to one node keeps that node's quantum and the build id."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is not connected; a subset of its first two nodes is."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a subset holding only the first two nodes
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """The graph splits into two connected, distinct three-quantum graphs."""
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Count, per original node, how many of the two sub-graphs contain
        # its quantum; the total must equal the size of the original graph.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        """The only input of the second node is the first node."""
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        """The first reported output of the second node is the third node."""
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        """The connections of the second node match a subset of the first three nodes."""
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        """The ancestors of the third node match a subset of the first three nodes."""
        allNodes = list(self.qGraph)
        ansestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ansestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        """No cycle is reported for the graph."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """A graph saved to a temporary file and loaded back compares equal."""
        with tempfile.TemporaryFile() as tmpFile:
            self.qGraph.save(tmpFile)
            # Rewind so that load reads from the start of what was written.
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._stripTaskNameQualification(restore)
            self.assertEqual(self.qGraph, restore)

    def testContains(self):
        """A node obtained by iterating the graph is reported as contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Subclass of `lsst.utils.tests.MemoryTestCase` that adds nothing of its
    own, so only the inherited behaviour applies.
    """
def setup_module(module):
    """Initialise the LSST test utilities.

    Parameters
    ----------
    module : module
        Not used by this function.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Script entry point: initialise the LSST test utilities, then hand
    # control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()