Coverage for tests/test_quantumGraph.py: 30%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from itertools import chain
23import os
24import pickle
25import tempfile
26from typing import Iterable
27import unittest
28import uuid
29import random
30from lsst.daf.butler import DimensionUniverse
32from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
33 DatasetTypeName)
34import lsst.pipe.base.connectionTypes as cT
35from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
36from lsst.pex.config import Field
37from lsst.pipe.base.graph.quantumNode import QuantumNode
38import lsst.utils.tests
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Either the AWS SDK or moto is unavailable. The S3 round-trip test is
    # skipped whenever ``boto3`` is None, so a pass-through decorator is all
    # that is needed in place of moto's ``mock_s3``.
    boto3 = None

    def mock_s3(cls):
        """Fallback for ``moto.mock_s3``: hand back ``cls`` unchanged."""
        return cls
# Arbitrary metadata attached to the test graph; compared after a save/load
# round trip.
METADATA = dict(a=[1, 2, 3])
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for ``Dummy1PipelineTask``.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` (both keyed on the
    ``A``/``B`` dimensions) plus the ``Dummy1InitOutput`` init-output.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )


class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for ``Dummy1PipelineTask``."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy1PipelineTask(PipelineTask):
    """Placeholder task; it defines no processing and only supplies ``ConfigClass``."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for ``Dummy2PipelineTask``.

    Consumes ``Dummy1Output`` (and ``Dummy1InitOutput`` as an init-input),
    producing ``Dummy2Output`` and the ``Dummy2InitOutput`` init-output.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )


class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for ``Dummy2PipelineTask``."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy2PipelineTask(PipelineTask):
    """Placeholder task; it defines no processing and only supplies ``ConfigClass``."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for ``Dummy3PipelineTask``.

    Consumes ``Dummy2Output`` (and ``Dummy2InitOutput`` as an init-input),
    producing ``Dummy3Output`` and the ``Dummy3InitOutput`` init-output.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )


class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for ``Dummy3PipelineTask``."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy3PipelineTask(PipelineTask):
    """Placeholder task; it defines no processing and only supplies ``ConfigClass``."""

    ConfigClass = Dummy3Config
# A task that shares no dataset types with the other tasks, used to check
# that an isolated task is handled correctly in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for ``Dummy4PipelineTask``: ``Dummy4Input`` -> ``Dummy4Output``."""

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )


class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for ``Dummy4PipelineTask``."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy4PipelineTask(PipelineTask):
    """Placeholder task; it defines no processing and only supplies ``ConfigClass``."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph of four tasks (labels ``R``, ``S``, ``T``, ``U``)
    with two quanta each, one per data ID ``{A: 1, B: 2}`` and
    ``{A: 3, B: 4}``. ``R`` -> ``S`` -> ``T`` are chained through their
    dataset types; ``U`` shares no dataset types with the others.
    """

    def setUp(self):
        # Minimal dimension universe with two integer-keyed dimensions A and B.
        universeConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=universeConfig)
        # Mapping of TaskDef -> set of Quantum used to build the graph.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T"),
                            (Dummy4PipelineTask, "U")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            # Dataset types depend only on the task's connections, so they are
            # built once per task rather than once per quantum.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe)
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe)
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # The quantum and its input/output refs all share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualifiers from the task names of both graphs.

        This is a hack for the unit test: the qualified name is ``__main__``
        here but qualified with the unittest module name once restored, so
        only the final dotted component is kept. Updates both graphs in place.
        """
        for saved, loaded in zip(graph1.taskGraph, graph2.taskGraph):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A freshly generated UUID is not the id of any node in the graph.
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        # R and U have no upstream producers in the graph.
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        # T and U have no downstream consumers in the graph.
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        # setUp creates two quanta (one per data ID) for every task.
        self.assertEqual(len(self.qGraph), 2*len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # Expected: every connection name except the init-outputs.
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        filteredNodes = [n for n in self.qGraph if n.taskDef.label != 'U']
        subset = self.qGraph.subset((filteredNodes[0], filteredNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        # One R->S->T chain and one isolated U node per data ID.
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # The total number of (node, component) memberships must equal the
        # number of nodes in the full graph.
        count = sum(1
                    for node in self.qGraph
                    for cg in connectedGraphs
                    if cg.checkQuantumInGraph(node.quantum))
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)
            # Previously this repeated the inputs check verbatim; check the
            # reverse direction instead. In a linear chain a node feeds at
            # most one other node, so the outputs are exactly allNodes[1].
            node = cg.determineOutputsOfQuantumNode(allNodes[0])
            self.assertEqual({allNodes[1]}, set(node))

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes = (set(self.qGraph.getNodesForTask(self.tasks[0]))
                      | set(self.qGraph.getNodesForTask(self.tasks[2]))
                      | set(testNodes))
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        # Build a fresh set rather than updating the object returned by
        # getNodesForTask in place; the result includes the queried nodes.
        matchNodes = set(self.qGraph.getNodesForTask(self.tasks[0])) | set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe,
                                           nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(nodeId))

    def testSaveLoadUri(self):
        # Reserve a uniquely named file and close it before use so saveUri can
        # write to it by path on every platform; it is removed in ``finally``.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self.assertEqual(restore.metadata, METADATA)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Pick two distinct nodes to load back as subsets.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                              graphID=self.qGraph._buildId)
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(nodeNumber))
            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(set(restoreSub),
                             {restore.getQuantumNodeByNodeId(nodeNumber),
                              restore.getQuantumNodeByNodeId(nodeNumber2)})
            # verify an error when requesting a non existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Leak check supplied by `lsst.utils.tests.MemoryTestCase`; adds no tests of its own."""
def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Initializes the LSST test utilities; ``module`` is accepted to match the
    hook signature but is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Run directly: initialize the LSST test helpers, then hand control to
    # unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()