Coverage for tests/test_quantumGraph.py: 31%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import QuantumNode
44from lsst.utils.introspection import get_full_type_name
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Either boto3 or moto is missing. Mark boto3 as unavailable, which makes
    # the S3 round-trip test skip. Also provide an identity decorator in place
    # of moto's mock_s3 so the decorated test method can still be defined.
    boto3 = None

    def mock_s3(cls):
        """Return *cls* unchanged; stand-in for moto's ``mock_s3`` decorator."""
        undecorated = cls
        return undecorated
# Metadata attached to the test QuantumGraph; checked after a save/load round trip.
METADATA = {"a": list(range(1, 4))}
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first task in the R -> S -> T chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` per (A, B) data ID, and
    declares the init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for ``Dummy1PipelineTask``, wired to ``Dummy1Connections``."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """Placeholder task; the tests only use it to build TaskDefs and Quanta."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the middle task in the R -> S -> T chain.

    Consumes Dummy1's ``Dummy1InitOutput`` and ``Dummy1Output`` and produces
    ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for ``Dummy2PipelineTask``, wired to ``Dummy2Connections``."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Placeholder task; the tests only use it to build TaskDefs and Quanta."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the last task in the R -> S -> T chain.

    Consumes Dummy2's ``Dummy2InitOutput`` and ``Dummy2Output`` and produces
    ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for ``Dummy3PipelineTask``, wired to ``Dummy3Connections``."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Placeholder task; the tests only use it to build TaskDefs and Quanta."""

    ConfigClass = Dummy3Config
# Dummy4 neither consumes nor produces any dataset used by the other tasks;
# it checks that such an isolated task is handled correctly in the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the isolated task: ``Dummy4Input`` -> ``Dummy4Output``."""

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for ``Dummy4PipelineTask``, wired to ``Dummy4Connections``."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy4PipelineTask(PipelineTask):
    """Placeholder task; the tests only use it to build TaskDefs and Quanta."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph."""

    def setUp(self):
        """Build the QuantumGraph shared by every test.

        Four tasks are created with labels "R", "S", "T" and "U".  Each task
        gets two quanta, at data IDs {A: 1, B: 2} and {A: 3, B: 4}.  R -> S -> T
        are chained through their dataset types, while U is isolated.
        """
        # Named separately from the per-task configs below to avoid shadowing.
        dimensionConfig = Config(
            {
                "version": 1,
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=dimensionConfig)
        # Mapping of TaskDef to the set of quanta belonging to that task.
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Dataset types depend only on the task, not on the data ID, so
            # they are built once per task.
            if connections.initInputs:
                initInputDSType = DatasetType(
                    connections.initInput.name,
                    tuple(),
                    storageClass=connections.initInput.storageClass,
                    universe=universe,
                )
            else:
                initInputDSType = None
            inputDSType = DatasetType(
                connections.input.name,
                connections.input.dimensions,
                storageClass=connections.input.storageClass,
                universe=universe,
            )
            outputDSType = DatasetType(
                connections.output.name,
                connections.output.dimensions,
                storageClass=connections.output.storageClass,
                universe=universe,
            )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def testTaskGraph(self):
        """Every TaskDef used to build the graph appears in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every input quantum is present among the graph's nodes."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Nodes can be looked up by ID; an unknown ID raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """The graph survives a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """Input quanta are those of the chain head (R) and the isolated task (U)."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        """Output quanta are those of the chain tail (T) and the isolated task (U)."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph holds two quanta per task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """getQuantaForTask returns exactly the quanta supplied for each task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        """getNodesForTask returns nodes wrapping exactly each task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """Dummy1Output is consumed by the second task (S)."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """Dummy1Output is produced by the first task (R)."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """Dummy1Output is associated with both its producer and its consumer."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """A TaskDef can be found by its task class's qualified name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """A TaskDef can be found by its label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """Quanta using Dummy1Input are exactly those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """allDatasetTypes lists every connection name except init-outputs."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that quantum and the parent's build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        """subsetToConnected splits the graph into its connected components."""
        # The graph is not connected.  Each data ID yields an independent
        # R -> S -> T chain, and every U quantum is isolated, giving four
        # components in total.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # The isolated task (tasks[3]) forms single-node components; the
        # chained tasks form three-node components.
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every node of the full graph lands in exactly one component.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(cg.taskGraph): set(cg.taskGraph) for cg in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            # Assumes (as the original check did) that iterating a chain yields
            # its nodes in dependency order.  In a three-node chain, node 1's
            # only input is node 0, and node 0's only output is node 1.
            self.assertEqual({allNodes[0]}, cg.determineInputsToQuantumNode(allNodes[1]))
            self.assertEqual({allNodes[1]}, set(cg.determineOutputsOfQuantumNode(allNodes[0])))

    def testDetermineOutputsOfQuantumNode(self):
        """The outputs of the first task's nodes are the second task's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        """The middle task's nodes connect to R's nodes, T's nodes and themselves."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # The result includes the queried nodes themselves, so add them.
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        """The ancestors of the middle task's nodes are R's nodes plus themselves."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        """The test graph is acyclic."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round-trip through save/load on a file object, fully and for one node."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node.
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testSaveLoadUri(self):
        """Round-trip through saveUri/loadUri on local disk, including partial loads."""
        # Reserve a unique path, then close the handle so that saveUri writes
        # to a path no other handle holds open.  delete=False keeps the file;
        # the finally clause below removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self.assertEqual(restore.metadata, METADATA)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct nodes at random.
            nodeNumber, nodeNumber2 = random.sample([n.nodeId for n in self.qGraph], 2)

            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

            # Verify that loading more than one node works.
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )

            # Verify an error when requesting a non-existent node number.
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # Verify that a graphID that does not match is an error.
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Round-trip through saveUri/loadUri against a mocked S3 bucket."""
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self.assertEqual(self.qGraph, restore)
        nodeId = next(iter(self.qGraph)).nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        """A node obtained by iterating the graph is reported as contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit the checks of ``lsst.utils.tests.MemoryTestCase`` unchanged.

    Presumably these are resource-leak checks run after the other tests;
    this subclass adds nothing of its own.
    """
def setup_module(module):
    """Initialize ``lsst.utils.tests`` before the module's tests run.

    Assumed to be invoked by pytest's module-level setup hook; the
    ``module`` argument is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: set up the LSST test utilities first, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()