Coverage for tests/test_quantumGraph.py: 22%
309 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-16 02:02 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-16 02:02 -0800
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import QuantumNode
44from lsst.utils.introspection import get_full_type_name
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Without boto3/moto the S3 test is skipped via ``unittest.skipIf``, but
    # the ``@mock_s3`` decorator is still evaluated at class-definition time,
    # so a pass-through stand-in is required.
    boto3 = None

    def mock_s3(target):
        """Return ``target`` unchanged; used when moto's ``mock_s3`` is unavailable."""
        return target


# Metadata attached to the fixture graph; the URI save/load test checks that
# it survives the round trip.
METADATA = {"a": list(range(1, 4))}
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first dummy task.

    ``Dummy1Output`` is consumed by the second dummy task, and
    ``Dummy1InitOutput`` is its init-input.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )


class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy1PipelineTask(PipelineTask):
    """Minimal task; only its config class and connections are used here."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second dummy task.

    It reads the first task's outputs (regular and init) and produces the
    datasets consumed by the third dummy task.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )


class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy2PipelineTask(PipelineTask):
    """Minimal task; only its config class and connections are used here."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third dummy task, the tail of the chain.

    Its outputs (``Dummy3Output``, ``Dummy3InitOutput``) are not consumed by
    any other dummy task.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )


class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy3PipelineTask(PipelineTask):
    """Minimal task; only its config class and connections are used here."""

    ConfigClass = Dummy3Config
# A task whose datasets are shared with no other task, used to check that an
# isolated task is handled correctly inside the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the standalone dummy task (no init connections)."""

    input = cT.Input(
        doc="n/a",
        name="Dummy4Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy4Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )


class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for `Dummy4PipelineTask`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)


class Dummy4PipelineTask(PipelineTask):
    """Minimal task; only its config class and connections are used here."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    The fixture graph holds two quanta (data IDs ``A=1, B=2`` and
    ``A=3, B=4``) for each of four tasks labelled ``R``, ``S``, ``T`` and
    ``U``. The first three are chained through their input/output datasets;
    ``U`` shares no datasets with the others.
    """

    def setUp(self):
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Connections lacking a ``dimensions`` attribute (presumably the
            # init connections) get an empty dimension set.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of quanta generated for it.
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Keep init-input and init-output refs in separate variables: only
            # the task's init-inputs belong on each Quantum's ``initInputs``.
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                initInputRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = initInputRefs
            else:
                initInputRefs = None
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputs[taskDef] = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]

            inputDSType = _makeDatasetType(connections.input)
            outputDSType = _makeDatasetType(connections.output)
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet

        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(
            quantumMap, metadata=METADATA, universe=universe, initInputs=initInputs, initOutputs=initOutputs
        )
        self.universe = universe

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        # R and U read datasets that no task in the graph produces.
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        # T and U write datasets that no task in the graph consumes.
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        # Two quanta (one per data ID) were created for every task.
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # Expected: every connection name except those declared only as
        # InitOutput connections.
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        # Not connected: each data ID forms its own R -> S -> T chain, and
        # the two U quanta are isolated, giving four components.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # task[3] (U) is expected to sit alone in its components.
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Membership summed over all components must match the full graph.
        count = sum(
            cg.checkQuantumInGraph(node.quantum) for node in self.qGraph for cg in connectedGraphs
        )
        self.assertEqual(len(self.qGraph), count)

        # Check the task content of every component, not just one per size.
        for cg in connectedGraphs:
            tskSet = set(cg.taskGraph)
            if len(tskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif len(tskSet) == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            # Iteration order puts each chain's producer before its consumer.
            allNodes = list(cg)
            self.assertEqual({allNodes[0]}, cg.determineInputsToQuantumNode(allNodes[1]))
            self.assertEqual({allNodes[1]}, cg.determineOutputsOfQuantumNode(allNodes[0]))

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # A universe differing only in version loads, with a log message.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # A universe with a different namespace is rejected.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        # Reserve a unique path, then close the handle so saveUri writes to a
        # path that no other open file object refers to.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri)
            self.assertEqual(restore.metadata, METADATA)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct random nodes for partial-load checks.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)

            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )
            # verify an error when requesting a non-existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from a mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit the LSST memory/leak checks unchanged; no extra tests here."""
def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    ``module`` is accepted to match the hook signature but is not used.
    """
    lsst.utils.tests.init()
# Script entry point: initialize the LSST test utilities, then run the
# unittest main loop. (Stray coverage-report annotation text removed from
# the ``if`` line.)
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()