Coverage for tests/test_quantumGraph.py: 29%
288 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-20 02:51 -0700
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-20 02:51 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import QuantumNode
44from lsst.utils.introspection import get_full_type_name
# boto3 and moto are optional: the S3 test is skipped when ``boto3`` is None.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Either import failing disables the S3 test entirely.
    boto3 = None

    def mock_s3(cls):
        """Return ``cls`` unchanged; placeholder used when moto's ``mock_s3``
        cannot be imported so the decorator syntax below still works.
        """
        return cls
# Metadata attached to the graph in ``setUp`` and compared after a save/load
# round trip in ``testSaveLoadUri``.
METADATA = {"a": [1, 2, 3]}
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first dummy task.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` and the init-output
    ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration bound to `Dummy1Connections`, with one integer field."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy1PipelineTask(PipelineTask):
    """First dummy task; configured by `Dummy1Config`."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second dummy task.

    Its init-input and input carry the dataset type names that
    `Dummy1Connections` declares as init-output and output.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration bound to `Dummy2Connections`, with one integer field."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; configured by `Dummy2Config`."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third dummy task.

    Its init-input and input carry the dataset type names that
    `Dummy2Connections` declares as init-output and output.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration bound to `Dummy3Connections`, with one integer field."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; configured by `Dummy3Config`."""

    ConfigClass = Dummy3Config
# Test that a task which does not interact with the other tasks works fine in
# the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the fourth dummy task.

    ``Dummy4Input`` and ``Dummy4Output`` are not used by any of the other
    dummy connection classes, and there are no init connections.
    """

    input = cT.Input(
        name="Dummy4Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy4Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration bound to `Dummy4Connections`, with one integer field."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy4PipelineTask(PipelineTask):
    """Fourth dummy task; configured by `Dummy4Config`."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from four dummy tasks labelled ``R``, ``S``,
    ``T`` and ``U`` with one quantum each for the data IDs ``(A=1, B=2)``
    and ``(A=3, B=4)``.  ``R``, ``S`` and ``T`` are chained through their
    dataset type names; ``U`` shares no dataset type with the others.
    """

    def setUp(self):
        """Create the dimension universe, the task definitions and the graph.

        Sets ``self.config``, ``self.universe``, ``self.tasks`` (list of
        `TaskDef` in label order R, S, T, U), ``self.quantumMap`` (mapping
        of `TaskDef` to its set of `Quantum`) and ``self.qGraph``.
        """
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # The dataset types depend only on the task, not on the data ID,
            # so build them once per task rather than once per quantum.
            if connections.initInputs:
                initInputDSType = DatasetType(
                    connections.initInput.name,
                    tuple(),
                    storageClass=connections.initInput.storageClass,
                    universe=universe,
                )
            else:
                initInputDSType = None
            inputDSType = DatasetType(
                connections.input.name,
                connections.input.dimensions,
                storageClass=connections.input.storageClass,
                universe=universe,
            )
            outputDSType = DatasetType(
                connections.output.name,
                connections.output.dimensions,
                storageClass=connections.output.storageClass,
                universe=universe,
            )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # One data ID is shared by the quantum and its refs.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA, universe=universe)
        self.universe = universe

    def testTaskGraph(self):
        """Every task definition must appear in the task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum given to the constructor must be in the graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Look up a node by its ID; an unknown ID raises `KeyError`."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip must compare equal to the original graph."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """Input quanta are those of the first task and the isolated task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        """Output quanta are those of the third task and the isolated task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph holds two quanta per task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """``getQuantaForTask`` returns the quanta registered for each task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        """Nodes returned for a task wrap exactly that task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """``Dummy1Output`` is consumed by the second task."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """``Dummy1Output`` is produced by the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is referenced by the first two tasks."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """Find the first task definition through the task's qualified name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Label ``R`` maps to the first task definition."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """``Dummy1Input`` is referenced only by the first task's quanta."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` equals the names of all non-init-output
        connections declared on the four dummy connection classes.
        """
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A single-node subset keeps the quantum and the build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        """Split the graph into its connected components and check each."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Summed over all components, each quantum must be counted exactly
        # once for the total to match the size of the full graph.
        count = sum(
            1 for node in self.qGraph for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        # Check the task content of every component rather than only one
        # representative per component size.
        for cg in connectedGraphs:
            tskSet = set(cg.taskGraph)
            if len(tskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif len(tskSet) == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        """Outputs of the first task's nodes are the second task's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        """Connections of the second task's nodes span tasks one to three."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        """Ancestors of the second task's nodes are themselves plus the
        first task's nodes.
        """
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        """The test graph has no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round trip through a file object, including partial loads and
        loads with a different dimension universe.
        """
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

            # Different universes.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            # A version-only difference is expected to load, emitting a log
            # record at INFO or above on the lsst.daf.butler logger.
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        """Round trip through a local file URI, including partial loads and
        the error paths for bad node IDs, graph IDs and file extensions.
        """
        uri = None
        try:
            # Only the name is needed; delete=False keeps the file on disk
            # after the handle is closed so saveUri can overwrite it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri)
            self.assertEqual(restore.metadata, METADATA)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct nodes at random in one step instead of
            # re-drawing until the second differs from the first.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)

            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )

            # verify an error when requesting a non existant node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Round trip through a mocked S3 bucket."""
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        """A node taken from the graph is reported as contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        """The save header records the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Pull the checks inherited from ``MemoryTestCase`` into this module."""
def setup_module(module):
    """Initialise the LSST test utilities when the module is set up.

    Parameters
    ----------
    module : module
        The test module being set up; not used here.
    """
    lsst.utils.tests.init()
# Direct execution: initialise the LSST test utilities, then run the tests.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()