Coverage for tests/test_quantumGraph.py : 28%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from itertools import chain
23import os
24import pickle
25import tempfile
26import unittest
27import random
28from lsst.daf.butler import DimensionUniverse
30from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections,
31 DatasetTypeName, IncompatibleGraphError)
32import lsst.pipe.base.connectionTypes as cT
33from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config
34from lsst.pex.config import Field
35from lsst.pipe.base.graph.quantumNode import NodeId, BuildId
36import lsst.utils.tests
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # boto3/moto are optional; without them the S3 test is skipped via
    # ``boto3 is None``, and ``mock_s3`` must still exist as a decorator.
    boto3 = None

    def mock_s3(cls):
        """Return ``cls`` untouched; stand-in used when moto is unavailable."""
        decorated = cls
        return decorated
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first task in the test pipeline.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` (both keyed by the
    ``A`` and ``B`` dimensions), and produces the init-output
    ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        doc="n/a",
        storageClass="ExposureF",
    )
    input = cT.Input(
        name="Dummy1Input",
        dimensions=("A", "B"),
        doc="n/a",
        storageClass="ExposureF",
    )
    output = cT.Output(
        name="Dummy1Output",
        dimensions=("A", "B"),
        doc="n/a",
        storageClass="ExposureF",
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for ``Dummy1PipelineTask``; ``conf1`` is a placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """Placeholder task with no processing of its own, used to build test quanta."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second task in the test pipeline.

    Consumes the first task's ``Dummy1InitOutput`` and ``Dummy1Output``,
    producing ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        doc="n/a",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        doc="n/a",
        storageClass="ExposureF",
    )
    input = cT.Input(
        name="Dummy1Output",
        dimensions=("A", "B"),
        doc="n/a",
        storageClass="ExposureF",
    )
    output = cT.Output(
        name="Dummy2Output",
        dimensions=("A", "B"),
        doc="n/a",
        storageClass="ExposureF",
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for ``Dummy2PipelineTask``; ``conf1`` is a placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Placeholder task with no processing of its own, used to build test quanta."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third (last) task in the test pipeline.

    Consumes the second task's ``Dummy2InitOutput`` and ``Dummy2Output``,
    producing ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        doc="n/a",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        doc="n/a",
        storageClass="ExposureF",
    )
    input = cT.Input(
        name="Dummy2Output",
        dimensions=("A", "B"),
        doc="n/a",
        storageClass="ExposureF",
    )
    output = cT.Output(
        name="Dummy3Output",
        dimensions=("A", "B"),
        doc="n/a",
        storageClass="ExposureF",
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for ``Dummy3PipelineTask``; ``conf1`` is a placeholder field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Placeholder task with no processing of its own, used to build test quanta."""

    ConfigClass = Dummy3Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    Every test starts from the graph built in `setUp`: three chained tasks
    (labels ``R`` -> ``S`` -> ``T``), each with one quantum per data ID
    ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``. This gives six quanta forming two
    disconnected three-quantum chains.
    """

    def setUp(self):
        """Build a two-dimension universe and the six-quantum test graph."""
        # Dimension configuration with two integer-keyed dimensions, A and B.
        # It is named distinctly so it is not confused with the per-task
        # configs created in the loop below.
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # Mapping of TaskDef to the set of quanta belonging to that task.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            for a, b in ((1, 2), (3, 4)):
                # One data ID is shared by the quantum and its input/output refs.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if connections.initInputs:
                    # Init-inputs have no dimensions, so they get an empty data ID.
                    initInputDSType = DatasetType(connections.initInput.name,
                                                  (),
                                                  storageClass=connections.initInput.storageClass,
                                                  universe=universe)
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                inputDSType = DatasetType(connections.input.name,
                                          connections.input.dimensions,
                                          storageClass=connections.input.storageClass,
                                          universe=universe,
                                          )
                inputRefs = [DatasetRef(inputDSType, dataId)]
                outputDSType = DatasetType(connections.output.name,
                                           connections.output.dimensions,
                                           storageClass=connections.output.storageClass,
                                           universe=universe,
                                           )
                outputRefs = [DatasetRef(outputDSType, dataId)]
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: inputRefs},
                            outputs={outputDSType: outputRefs}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualifiers from the quantum task names of both graphs.

        This is a hack for the unit test. The qualified task name is
        ``__main__``-based when the graph is built here, but it is qualified
        with the unittest module name when the graph is restored. Both graphs
        are updated in place so that they compare equal.
        """
        for saved, loaded in zip(graph1._quanta.keys(),
                                 graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId carrying a different build id must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # Take a two-node subset of the (disconnected) full graph.
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum of the full graph must land in exactly one of the
        # connected pieces.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        # Previously this check was duplicated at the end of
        # testSubsetToConnected; it now has a dedicated test.
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)
                # Draw two distinct node numbers at once instead of
                # re-drawing the second one until it differs.
                nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                                  graphID=self.qGraph._buildId)
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0],
                                 restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(set(restoreSub),
                                 {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                                  restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # The temporary file was created with delete=False, so remove it
            # here whether or not the assertions above passed.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the LSST memory/resource checks for this test module unchanged."""
def setup_module(module):
    """Module-level setup hook: initialise the LSST test utilities.

    The ``module`` argument is accepted to match the hook signature but is
    not used.
    """
    init = lsst.utils.tests.init
    init()
if __name__ == "__main__":
    # Running the file directly: initialise the LSST test utilities first,
    # then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()