Coverage for tests/test_quantumGraph.py: 28%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from itertools import chain
23import os
24import pickle
25import tempfile
26from typing import Iterable
27import unittest
28import random
29from lsst.daf.butler import DimensionUniverse
31from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
32 DatasetTypeName, IncompatibleGraphError)
33import lsst.pipe.base.connectionTypes as cT
34from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
35from lsst.pex.config import Field
36from lsst.pipe.base.graph.quantumNode import NodeId, BuildId, QuantumNode
37import lsst.utils.tests
try:
    import boto3
    try:
        from moto import mock_s3
    except ImportError:
        # moto >= 5 removed the per-service decorators (mock_s3, ...) in
        # favour of a single ``mock_aws``; alias it so the S3 test still runs
        # instead of being skipped as if boto3 were missing.
        from moto import mock_aws as mock_s3
except ImportError:
    # Either boto3 or moto is unavailable: S3 tests are skipped via the
    # ``boto3 is None`` check, and the decorator becomes a no-op.
    boto3 = None

    def mock_s3(cls):
        """A no-op decorator in case moto mock_s3 can not be imported.
        """
        return cls
# Arbitrary metadata attached to the test graph; round-trip tests check it survives save/load.
METADATA = dict(a=[1, 2, 3])
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first (head) task: consumes ``Dummy1Input`` and
    produces ``Dummy1Output`` plus the init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
61 output = cT.Output(name="Dummy1Output",
62 storageClass="ExposureF",
63 doc="n/a",
64 dimensions=("A", "B"))
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`; wires in `Dummy1Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """Placeholder task at the head of the R -> S -> T chain."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the middle task: consumes what `Dummy1Connections`
    produces (``Dummy1Output`` and ``Dummy1InitOutput``) and produces
    ``Dummy2Output`` / ``Dummy2InitOutput``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`; wires in `Dummy2Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Placeholder task in the middle of the R -> S -> T chain."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the last (tail) task: consumes ``Dummy2Output`` and
    ``Dummy2InitOutput`` and produces ``Dummy3Output`` / ``Dummy3InitOutput``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`; wires in `Dummy3Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Placeholder task at the tail of the R -> S -> T chain."""

    ConfigClass = Dummy3Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph
    """

    def setUp(self):
        """Build ``self.qGraph`` from the three chained dummy tasks.

        Tasks R -> S -> T each get one quantum per data ID
        ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``, so the graph holds six quanta
        forming two independent three-quantum chains.
        """
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            # Dataset types depend only on the task's connections, not on the
            # data ID, so build them once per task.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # The quantum and its input/output refs all share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualification from every quantum's ``taskName``.

        This is a hack for the unit test since the qualified name will be
        different as it will be __main__ here, but qualified to the unittest
        module name when restored. Updates both graphs in place. Each graph
        is walked in full (rather than zipped together), so a size mismatch
        cannot leave some quanta uncleaned and mask the real difference.
        """
        for quantum in chain(graph1._quanta.keys(), graph2._quanta.keys()):
            quantum.taskName = quantum.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId from a different build must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        # 3 tasks x 2 data IDs
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a subset of two nodes
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every node of the full graph must land in exactly one component.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(set(self.tasks), set(connectedGraphs[0].taskGraph))
        self.assertEqual(set(self.tasks), set(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self.assertEqual(restore.metadata, METADATA)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)
                # Two distinct random node numbers: one for a single-node
                # load, both for a multi-node load.
                nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                                  graphID=self.qGraph._buildId)
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0],
                                 restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(set(restoreSub),
                                 {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                                  restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # The temp file was created with delete=False, so remove it here
            # whether or not the assertions above passed.
            if uri is not None:
                os.remove(uri)

        # An unrecognized extension must be rejected.
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Leak check inherited unchanged from `lsst.utils.tests.MemoryTestCase`."""
def setup_module(module):
    """Module-level setup hook: initialize the ``lsst.utils.tests`` helpers
    before this module's tests run.

    Parameters
    ----------
    module : `module`
        The module being set up (not used).
    """
    lsst.utils.tests.init()
# Allow running this file directly; the coverage-report annotation that was
# fused onto the ``if`` line has been removed so the statement parses again.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()