Coverage for tests/test_quantumGraph.py : 28%

Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from itertools import chain
23import os
24import pickle
25import tempfile
26from typing import Iterable
27import unittest
28import random
29from lsst.daf.butler import DimensionUniverse
31from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
32 DatasetTypeName, IncompatibleGraphError)
33import lsst.pipe.base.connectionTypes as cT
34from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
35from lsst.pex.config import Field
36from lsst.pipe.base.graph.quantumNode import NodeId, BuildId, QuantumNode
37import lsst.utils.tests
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # boto3 and/or moto is unavailable. A None ``boto3`` is what the S3 test
    # checks in its ``skipIf`` guard, so the S3 round-trip is skipped.
    boto3 = None

    def mock_s3(cls):
        """Stand-in for ``moto.mock_s3``: hand back the decorated object as-is.

        This keeps the ``@mock_s3`` decorator usable at class-definition time
        even when moto cannot be imported.
        """
        return cls
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first task (label ``R``) of the test chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` over the ``A``/``B``
    dimensions, and declares the init-output ``Dummy1InitOutput`` that the
    second task consumes as its init-input.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, wired to `Dummy1Connections`."""

    # Placeholder field; nothing in this test module reads it.
    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy1PipelineTask(PipelineTask):
    """Minimal task type for the first quanta of the test graph.

    It only binds `Dummy1Config`; this module never runs it.
    """

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the middle task (label ``S``) of the test chain.

    Consumes the first task's ``Dummy1InitOutput`` and ``Dummy1Output``,
    and produces ``Dummy2InitOutput`` and ``Dummy2Output`` for the third task.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, wired to `Dummy2Connections`."""

    # Placeholder field; nothing in this test module reads it.
    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy2PipelineTask(PipelineTask):
    """Minimal task type for the middle quanta of the test graph.

    It only binds `Dummy2Config`; this module never runs it.
    """

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the last task (label ``T``) of the test chain.

    Consumes the middle task's ``Dummy2InitOutput`` and ``Dummy2Output``,
    and produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, wired to `Dummy3Connections`."""

    # Placeholder field; nothing in this test module reads it.
    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy3PipelineTask(PipelineTask):
    """Minimal task type for the final quanta of the test graph.

    It only binds `Dummy3Config`; this module never runs it.
    """

    ConfigClass = Dummy3Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a chain of three tasks (labels ``R``, ``S``, ``T``) in
    which each task consumes the previous task's output. Each task gets one
    quantum per data ID ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``, so the tests
    below expect six quanta split across two disconnected three-quantum chains.
    """

    def setUp(self):
        universeConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=universeConfig)
        # Mapping of TaskDef to the set of quanta generated for that task.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskDef = TaskDef(f"__main__.{task.__qualname__}", task.ConfigClass(), task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            # Dataset types depend only on the task, not on the data ID, so
            # build them once per task rather than once per quantum.
            # Only tasks that declare init-inputs have an ``initInput``
            # connection to read (Dummy1Connections has none).
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe)
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe)
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # One data ID shared by the quantum and its input/output refs.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]})
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualification from the task names of both graphs' quanta.

        This is a hack for the unit test: the qualified name is ``__main__``
        here but is qualified with the unittest module name once a graph has
        been restored. Both graphs are updated in place so that they can be
        compared for equality.
        """
        for saved, loaded in zip(graph1._quanta.keys(), graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId from a different build must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections)
            for connection in conClass.allConnections.values()  # type: ignore
        }
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # Subset to the first two nodes, which are expected to belong to the
        # same chain.
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum of the full graph must appear in exactly one of the
        # connected subgraphs in total.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        # A temporary directory (rather than an open NamedTemporaryFile) lets
        # saveUri create the file by name on every platform, and removes it
        # automatically.
        with tempfile.TemporaryDirectory() as tmpDir:
            uri = os.path.join(tmpDir, "graph.qgraph")
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct node numbers in a single call.
            nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                              graphID=self.qGraph._buildId)
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))

            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(set(restoreSub),
                             {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                              restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})

            # verify an error when requesting a non existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")

            # An unsupported extension must be rejected; keep the path inside
            # the temporary directory so nothing is ever written to the CWD.
            with self.assertRaises(TypeError):
                self.qGraph.saveUri(os.path.join(tmpDir, "test.notgraph"))

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Pull the ``lsst.utils.tests.MemoryTestCase`` checks into this module.

    The class adds nothing of its own; all behavior is inherited.
    """
def setup_module(module):
    """Initialize ``lsst.utils.tests`` once for this module.

    Presumably invoked by pytest as its xunit-style module setup hook. The
    ``module`` argument is accepted for that signature but is not used.
    """
    lsst.utils.tests.init()
# Allow running this test module directly as a script. The coverage-report
# branch annotation that had been fused onto the ``if`` line (a SyntaxError)
# has been removed.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()