Coverage for tests/test_quantumGraph.py: 30%
Shortcuts on this page:
  r, m, x, p - toggle line displays
  j, k - next/previous highlighted chunk
  0 (zero) - top of page
  1 (one) - first highlighted chunk
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import QuantumNode
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Either import failing disables the S3 test: ``boto3`` is set to None so
    # it can be used as the skip condition, and ``mock_s3`` is replaced by a
    # pass-through so that decorating with it still works.
    boto3 = None

    def mock_s3(cls):
        """Return the decorated object unchanged.

        Stand-in for ``moto.mock_s3`` used when it can not be imported.
        """
        return cls
# Metadata attached to the graph built in QuantumGraphTestCase.setUp and
# compared against the metadata of a graph restored with loadUri.
METADATA = {"a": [1, 2, 3]}
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Reads ``Dummy1Input`` and writes ``Dummy1Output``, plus an init-output
    named ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration bound to `Dummy1Connections`, with one integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """Task that only sets its configuration class to `Dummy1Config`."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Consumes the datasets named ``Dummy1InitOutput`` and ``Dummy1Output`` and
    produces ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration bound to `Dummy2Connections`, with one integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Task that only sets its configuration class to `Dummy2Config`."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Consumes the datasets named ``Dummy2InitOutput`` and ``Dummy2Output`` and
    produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration bound to `Dummy3Connections`, with one integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Task that only sets its configuration class to `Dummy3Config`."""

    ConfigClass = Dummy3Config
# Test that a task which does not interact with the other tasks works fine in
# the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the fourth dummy task.

    Reads ``Dummy4Input`` and writes ``Dummy4Output``; it declares no
    init-input or init-output connections.
    """

    input = cT.Input(
        name="Dummy4Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy4Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration bound to `Dummy4Connections`, with one integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy4PipelineTask(PipelineTask):
    """Task that only sets its configuration class to `Dummy4Config`."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from the four dummy tasks (labels "R", "S", "T"
    and "U"), each with two quanta: one for data ID ``A=1, B=2`` and one for
    ``A=3, B=4``.
    """

    def setUp(self):
        """Create the dimension universe, the task/quanta map and the graph.

        Sets ``self.tasks`` (list of `TaskDef`, in label order R, S, T, U),
        ``self.quantumMap`` (`TaskDef` -> set of `Quantum`), ``self.qGraph``
        and ``self.universe``.
        """
        universeConfig = Config(
            {
                "version": 1,
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=universeConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            taskDef = TaskDef(f"__main__.{task.__qualname__}", task.ConfigClass(), task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # The dataset types depend only on the task, not on the data ID,
            # so they are built once per task rather than once per quantum.
            if connections.initInputs:
                initInputDSType = DatasetType(
                    connections.initInput.name,
                    tuple(),
                    storageClass=connections.initInput.storageClass,
                    universe=universe,
                )
            else:
                initInputDSType = None
            inputDSType = DatasetType(
                connections.input.name,
                connections.input.dimensions,
                storageClass=connections.input.storageClass,
                universe=universe,
            )
            outputDSType = DatasetType(
                connections.output.name,
                connections.output.dimensions,
                storageClass=connections.output.storageClass,
                universe=universe,
            )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # One data ID is shared by the quantum and its input/output.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip the module qualification from every task name, in place.

        This is a hack for the unit test since the qualified name will be
        different as it will be __main__ here, but qualified to the
        unittest module name when restored.
        """
        for saved, loaded in zip(graph1.taskGraph, graph2.taskGraph):
            saved.taskName = saved.taskName.split(".")[-1]
            loaded.taskName = loaded.taskName.split(".")[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A freshly generated UUID can not be the ID of a node in the graph.
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        # The pickled bytes are produced right here, so loading them is safe.
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        # setUp creates two quanta per task.
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        found = self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output"))
        self.assertEqual(tuple(found)[0], self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # Expected names: every connection of every dummy task except the
        # init-outputs.
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        filteredNodes = [n for n in self.qGraph if n.taskDef.label != "U"]
        subset = self.qGraph.subset((filteredNodes[0], filteredNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Summed over all sub-graphs, the number of (node, sub-graph)
        # memberships must equal the number of nodes in the full graph.
        count = sum(
            1
            for node in self.qGraph
            for cg in connectedGraphs
            if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        # Check the task content of every sub-graph (not just one per size).
        for cg in connectedGraphs:
            tskSet = set(cg.taskGraph)
            if len(tskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif len(tskSet) == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testSaveLoadUri(self):
        # Only the file name is needed; close the handle before the graph is
        # written to that path. delete=False keeps the file until the finally
        # clause below removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self.assertEqual(restore.metadata, METADATA)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct nodes at random.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)

            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )

            # verify an error when requesting a non-existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            os.remove(uri)

        # A file name without the expected extension is rejected.
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from a mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        nodeId = next(iter(self.qGraph)).nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Pull the checks of `lsst.utils.tests.MemoryTestCase` into this module.

    Nothing is overridden; the subclass only exists so the inherited tests
    are collected here.
    """
def setup_module(module):
    """Initialise the LSST test utilities; ``module`` itself is not used."""
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Run directly as a script: initialise the LSST test utilities, then hand
    # over to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()