Coverage for tests/test_quantumGraph.py : 31%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

22from itertools import chain
23import pickle
24import tempfile
25import unittest
26from lsst.daf.butler import DimensionUniverse
28from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
29 DatasetTypeName, IncompatibleGraphError)
30import lsst.pipe.base.connectionTypes as cT
31from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
32from lsst.pex.config import Field
33from lsst.pipe.base.graph.quantumNode import NodeId, BuildId
34import lsst.utils.tests
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first task in the test chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` over the (A, B)
    dimensions, and declares the init-output ``Dummy1InitOutput``.
    """
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration bound to `Dummy1Connections`; holds one dummy field."""
    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """Stub task used only for graph construction; binds `Dummy1Config`."""
    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second task in the test chain.

    Consumes the first task's ``Dummy1InitOutput`` and ``Dummy1Output`` and
    produces ``Dummy2InitOutput`` and ``Dummy2Output``.
    """
    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration bound to `Dummy2Connections`; holds one dummy field."""
    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Stub task used only for graph construction; binds `Dummy2Config`."""
    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third (last) task in the test chain.

    Consumes the second task's ``Dummy2InitOutput`` and ``Dummy2Output`` and
    produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """
    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration bound to `Dummy3Connections`; holds one dummy field."""
    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Stub task used only for graph construction; binds `Dummy3Config`."""
    ConfigClass = Dummy3Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from three chained tasks (labels ``"R"``,
    ``"S"`` and ``"T"``). Each task gets one quantum for the data ID
    ``{"A": 1, "B": 2}`` and one for ``{"A": 3, "B": 4}``, so the graph
    holds six quanta forming two separate three-quantum chains.
    """

    def setUp(self):
        universeConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=universeConfig)
        # Mapping of each TaskDef to the set of quanta that task will run.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Dataset types do not depend on the data ID, so build them once
            # per task instead of once per quantum.
            initInputDSType = None
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType,
                                           DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                inputRefs = [DatasetRef(inputDSType, dataId)]
                outputRefs = [DatasetRef(outputDSType, dataId)]
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: inputRefs},
                            outputs={outputDSType: outputRefs}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _stripTaskNameQualifiers(self, restored):
        """Reduce quantum task names to their last dotted component.

        This is a hack for the unit test: the qualified name differs because
        it is ``__main__`` here but is qualified by the unittest module name
        when restored. Both ``self.qGraph`` and ``restored`` are modified in
        place so that the two graphs can be compared for equality.
        """
        for saved, loaded in zip(self.qGraph._quanta.keys(), restored._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId from a different build must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._stripTaskNameQualifiers(restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWithDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        # A subset keeps the build id of the graph it was taken from.
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # Subset to the first two nodes only (a partial graph).
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum of the full graph must land in exactly one of the
        # connected subgraphs.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAncestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile() as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._stripTaskNameQualifiers(restore)
            self.assertEqual(self.qGraph, restore)

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Runs the checks inherited from ``lsst.utils.tests.MemoryTestCase``.

    Adds nothing of its own; defining the subclass is what makes the
    inherited checks run for this module.
    """
def setup_module(module):
    """Pytest module-level setup hook that calls ``lsst.utils.tests.init()``.

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    lsst.utils.tests.init()
# Allow running this file directly as a script; the coverage annotation
# previously fused onto the ``if`` line made it a syntax error.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()