Coverage for tests/test_quantumGraph.py : 31%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from itertools import chain
23import pickle
24import tempfile
25import unittest
26from lsst.daf.butler import DimensionUniverse
28from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
29 DatasetTypeName, IncompatibleGraphError)
30import lsst.pipe.base.connectionTypes as cT
31from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
32from lsst.pex.config import Field
33from lsst.pipe.base.graph.quantumNode import NodeId, BuildId
34import lsst.utils.tests
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task used by the tests.

    Declares one init-output (``Dummy1InitOutput``), one input
    (``Dummy1Input``) and one output (``Dummy1Output``); the input and
    output are defined over dimensions ``("A", "B")``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    # Single placeholder integer option, defaulting to 1.
    conf1 = Field(
        doc="dummy config",
        dtype=int,
        default=1,
    )
class Dummy1PipelineTask(PipelineTask):
    """First dummy task; it only declares its configuration class."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task used by the tests.

    Its init-input (``Dummy1InitOutput``) and input (``Dummy1Output``) carry
    the names of the init-output and output declared by
    `Dummy1Connections`; it declares ``Dummy2InitOutput`` and
    ``Dummy2Output`` in turn.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    # Single placeholder integer option, defaulting to 1.
    conf1 = Field(
        doc="dummy config",
        dtype=int,
        default=1,
    )
class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; it only declares its configuration class."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task used by the tests.

    Its init-input (``Dummy2InitOutput``) and input (``Dummy2Output``) carry
    the names of the init-output and output declared by
    `Dummy2Connections`; it declares ``Dummy3InitOutput`` and
    ``Dummy3Output`` in turn.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    # Single placeholder integer option, defaulting to 1.
    conf1 = Field(
        doc="dummy config",
        dtype=int,
        default=1,
    )
class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; it only declares its configuration class."""

    ConfigClass = Dummy3Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from three dummy tasks (labels "R", "S", "T"),
    each consuming the dataset type produced by the previous one, with one
    quantum per task for each of the data IDs ``{A: 1, B: 2}`` and
    ``{A: 3, B: 4}``.
    """

    def setUp(self):
        """Build the dimension universe, task definitions and quantum graph
        shared by every test.
        """
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # The dataset types depend only on the task's connections, so
            # build them once per task rather than once per data ID.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # One data ID is shared by the quantum and its input and
                # output references.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType,
                                           DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _stripTaskNameQualification(self, restore):
        """Reduce task names in ``self.qGraph`` and ``restore`` to their last
        dotted component, in place.

        This is a hack for the unit test since the qualified name will be
        different as it will be __main__ here, but qualified to the
        unittest module name when restored.

        Parameters
        ----------
        restore : `QuantumGraph`
            Graph restored from a serialized copy of ``self.qGraph``.
        """
        for saved, loaded in zip(self.qGraph._quanta.keys(),
                                 restore._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        """Every task definition must appear in the task graph."""
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum passed to the constructor must appear in the graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Look up a node by its id and reject an id from another build."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A graph must compare equal to its pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._stripTaskNameQualification(restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """The input quanta are those of the first task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        """The output quanta are those of the last task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        """Three tasks with two quanta each give a graph of length six."""
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        """Each task maps back to the quanta it was constructed with."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        """``Dummy1Output`` is an input of the second task."""
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        """``Dummy1Output`` is an output of the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is associated with the first two tasks."""
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        """The first match for the Dummy1 task name is the first task
        definition.
        """
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Label "R" resolves to the first task definition."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """``Dummy1Input`` is associated with the quanta of the first task."""
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        """The graph reports the name of every connection of every task."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections)
            for connection in conClass.allConnections.values()  # type: ignore
        }
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A single-node subset holds that node and keeps the build id."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is disconnected; a subset of the first two nodes
        is connected.
        """
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # Subset down to the first two nodes.
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """Splitting into connected graphs yields two distinct three-node
        graphs that together cover every quantum.
        """
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        for graph in connectedGraphs:
            self.assertTrue(graph.isConnected)
            self.assertEqual(len(graph), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Count every (node, connected graph) membership; the total is
        # compared against the number of nodes in the full graph.
        count = sum(bool(graph.checkQuantumInGraph(node.quantum))
                    for node in self.qGraph
                    for graph in connectedGraphs)
        self.assertEqual(len(self.qGraph), count)

        for graph in connectedGraphs:
            self.assertEqual(self.tasks, list(graph.taskGraph))

    def testDetermineInputsToQuantumNode(self):
        """The inputs of the second node are exactly the first node."""
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        """The first reported output of the second node is the third node."""
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        """The connections of the second node match a subset of the first
        three nodes.
        """
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        """The ancestors of the third node match a subset of the first three
        nodes.
        """
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        """No cycle is reported for the test graph."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """A graph must compare equal to its save/load round trip."""
        with tempfile.TemporaryFile() as tmpFile:
            self.qGraph.save(tmpFile)
            # Rewind so that load reads from the start of what was written.
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._stripTaskNameQualification(restore)
            self.assertEqual(self.qGraph, restore)

    def testContains(self):
        """A node obtained by iterating the graph is contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Memory test case inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """
def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Parameters
    ----------
    module
        Accepted but not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: initialise the LSST test utilities first, then let
    # unittest run the tests.
    lsst.utils.tests.init()
    unittest.main()