Coverage for tests/test_quantumGraph.py : 30%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from itertools import chain
23import os
24import pickle
25import tempfile
26import unittest
27from lsst.daf.butler import DimensionUniverse
29from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
30 DatasetTypeName, IncompatibleGraphError)
31import lsst.pipe.base.connectionTypes as cT
32from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
33from lsst.pex.config import Field
34from lsst.pipe.base.graph.quantumNode import NodeId, BuildId
35import lsst.utils.tests
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first task of the dummy chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` (both with dimensions
    ``A``, ``B``), plus the dimensionless init-output ``Dummy1InitOutput``
    that the second task consumes.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask` wired to `Dummy1Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """Placeholder task (no methods defined) used to build test quanta."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second task of the dummy chain.

    Consumes the first task's ``Dummy1InitOutput`` and ``Dummy1Output`` and
    produces ``Dummy2InitOutput`` and ``Dummy2Output`` for the third task.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask` wired to `Dummy2Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Placeholder task (no methods defined) used to build test quanta."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third (final) task of the dummy chain.

    Consumes the second task's ``Dummy2InitOutput`` and ``Dummy2Output`` and
    produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask` wired to `Dummy3Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Placeholder task (no methods defined) used to build test quanta."""

    ConfigClass = Dummy3Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from three chained dummy tasks (labels ``R``,
    ``S`` and ``T``; each task's input is the previous task's output).
    Every task gets one quantum for data ID ``{A: 1, B: 2}`` and one for
    ``{A: 3, B: 4}``, so the graph holds six quanta forming two
    independent three-quantum chains.
    """

    def setUp(self):
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # Mapping of TaskDef to the set of quanta generated for that task.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Dataset types depend only on the task's connections, not on the
            # data ID, so build them once per task.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                # The first task in the chain declares no init-inputs.
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # The quantum and its input/output refs share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualification from quantum task names, in place.

        This is a hack for the unit test: the qualified task name is
        ``__main__``-based here but is qualified with the unittest module
        name when a graph is restored, so both sides are reduced to the
        final dotted component before comparison.
        """
        for saved, loaded in zip(graph1._quanta.keys(),
                                 graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId from a different build must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        # Three tasks times two data IDs.
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWithDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        # A subset keeps the build ID of the graph it came from.
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions.
        self.assertFalse(self.qGraph.isConnected)
        # Subset to two nodes; the first two iterated nodes are expected to
        # belong to the same chain.
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because the subset only contains (part of) one chain.
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum of the full graph must land in exactly one of the
        # connected subgraphs.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAncestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile() as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

    def testSaveLoadUri(self):
        # A temporary directory is removed automatically, and it keeps any
        # file written by a misbehaving saveUri out of the working directory.
        with tempfile.TemporaryDirectory() as tmpDir:
            uri = os.path.join(tmpDir, "graph.pickle")
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

            # A non-".pickle" destination is expected to be rejected.
            with self.assertRaises(TypeError):
                self.qGraph.saveUri(os.path.join(tmpDir, "test.notpickle"))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the inherited ``lsst.utils.tests.MemoryTestCase`` checks for this module."""
def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` convention).

    Calls ``lsst.utils.tests.init()`` before any test in this module runs.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Mirror setup_module when run as a script: init, then hand off to unittest.
    lsst.utils.tests.init()
    unittest.main()