Coverage for tests/test_quantumGraph.py : 28%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from itertools import chain
23import os
24import pickle
25import tempfile
26from typing import Iterable
27import unittest
28import random
29from lsst.daf.butler import DimensionUniverse
31from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
32 DatasetTypeName, IncompatibleGraphError)
33import lsst.pipe.base.connectionTypes as cT
34from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
35from lsst.pex.config import Field
36from lsst.pipe.base.graph.quantumNode import NodeId, BuildId, QuantumNode
37import lsst.utils.tests
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Either the AWS SDK or its mocking library is unavailable; the S3 test
    # is guarded by ``skipIf(not boto3, ...)``, so a pass-through decorator
    # is enough to keep the module importable.
    boto3 = None

    def mock_s3(cls):
        """Stand-in for moto's ``mock_s3`` used when it cannot be imported.

        Returns the decorated object unchanged.
        """
        return cls
# Arbitrary metadata attached to the test graph; round-tripping it through
# save/load is checked in testSaveLoadUri.
METADATA = {"a": [1, 2, 3]}
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first task of the test chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` per (A, B) data ID, and
    declares the ``Dummy1InitOutput`` init-output.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """Minimal task for the first stage; it only declares its config class."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second task of the test chain.

    Consumes ``Dummy1InitOutput`` and ``Dummy1Output`` and produces
    ``Dummy2InitOutput`` and ``Dummy2Output`` per (A, B) data ID.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Minimal task for the second stage; it only declares its config class."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third (last) task of the test chain.

    Consumes ``Dummy2InitOutput`` and ``Dummy2Output`` and produces
    ``Dummy3InitOutput`` and ``Dummy3Output`` per (A, B) data ID.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Minimal task for the third stage; it only declares its config class."""

    ConfigClass = Dummy3Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a three-task chain (labels ``R`` -> ``S`` -> ``T``, each
    task consuming the previous task's output) in which every task runs on
    two independent (A, B) data IDs. The resulting graph has six quanta
    forming two disconnected three-node chains.
    """

    def setUp(self):
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # Mapping of TaskDef to the set of quanta belonging to that task.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            # Dataset types depend only on the task's connections, not on the
            # data ID, so they are built once per task.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        # This is a hack for the unit test since the qualified name will be
        # different as it will be __main__ here, but qualified to the
        # unittest module name when restored.
        # Updates the quanta of both graphs in place.
        for saved, loaded in zip(graph1._quanta.keys(),
                                 graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions.
        self.assertFalse(self.qGraph.isConnected)
        # Subset to the first two nodes, which belong to the same chain.
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs.
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum of the full graph must land in exactly one component.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(set(self.tasks), set(connectedGraphs[0].taskGraph))
        self.assertEqual(set(self.tasks), set(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        # Use a temporary directory instead of NamedTemporaryFile(delete=False):
        # the graph is written by path, so no handle to that file may be held
        # open (required on Windows), and everything is cleaned up even when
        # an assertion fails.
        with tempfile.TemporaryDirectory() as tmpDir:
            uri = os.path.join(tmpDir, "graph.qgraph")
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self.assertEqual(restore.metadata, METADATA)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Pick two distinct node numbers at random.
            nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                              graphID=self.qGraph._buildId)
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))
            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(set(restoreSub),
                             {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                              restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})
            # verify an error when requesting a non-existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")

            # Saving to a URI without the ``.qgraph`` suffix must raise
            # TypeError; the target lives in tmpDir so nothing is left behind
            # in the working directory should that ever regress.
            with self.assertRaises(TypeError):
                self.qGraph.saveUri(os.path.join(tmpDir, "test.notgraph"))

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Pull in the shared ``lsst.utils.tests.MemoryTestCase`` checks for this module."""
def setup_module(module):
    """Module-level test setup hook: initialise the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The test module being set up; not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Mirror setup_module's initialisation when run directly as a script.
    lsst.utils.tests.init()
    unittest.main()