Coverage for tests/test_quantumGraph.py: 22%
306 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-04 09:40 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import QuantumNode
44from lsst.utils.introspection import get_full_type_name
# boto3 and moto are optional.  When either import fails, ``boto3`` is set to
# None (the S3 test checks this to skip itself) and ``mock_s3`` becomes an
# identity decorator so the decorated test class still loads.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    boto3 = None

    def mock_s3(cls):
        """Stand-in for moto's ``mock_s3``: hand back ``cls`` untouched."""
        undecorated = cls
        return undecorated
57METADATA = {"a": [1, 2, 3]}
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first task (label ``R`` in the test graph).

    It reads ``Dummy1Input`` and writes ``Dummy1Output`` for each (A, B) data
    ID. It also declares the init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
    output = cT.Output(
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for ``Dummy1PipelineTask`` with a single dummy field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy1PipelineTask(PipelineTask):
    """Placeholder task; the tests only build quanta for it and never run it."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second task (label ``S`` in the test graph).

    It consumes the outputs of the first dummy task (``Dummy1InitOutput`` and
    ``Dummy1Output``), linking it downstream of that task.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
    output = cT.Output(
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for ``Dummy2PipelineTask`` with a single dummy field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy2PipelineTask(PipelineTask):
    """Placeholder task; the tests only build quanta for it and never run it."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third task (label ``T`` in the test graph).

    It consumes the outputs of the second dummy task (``Dummy2InitOutput`` and
    ``Dummy2Output``). Nothing in the test pipeline consumes its own outputs.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
    output = cT.Output(
        name="Dummy3Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for ``Dummy3PipelineTask`` with a single dummy field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy3PipelineTask(PipelineTask):
    """Placeholder task; the tests only build quanta for it and never run it."""

    ConfigClass = Dummy3Config
# Check that a task which shares no datasets with the other tasks is still
# handled correctly by the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the isolated task (label ``U`` in the test graph).

    Its ``Dummy4Input``/``Dummy4Output`` datasets are not shared with any other
    dummy task, and it declares no init connections.
    """

    input = cT.Input(
        name="Dummy4Input",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
    output = cT.Output(
        name="Dummy4Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
        doc="n/a",
    )
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for ``Dummy4PipelineTask`` with a single dummy field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)
class Dummy4PipelineTask(PipelineTask):
    """Placeholder task; the tests only build quanta for it and never run it."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from four dummy tasks labelled R, S, T and U:

    - R -> S -> T form a chain, linked through their ``Dummy*Output`` and
      ``Dummy*InitOutput`` datasets.
    - U shares no datasets with the others.

    Each task gets one quantum per data ID ``{A: 1, B: 2}`` and
    ``{A: 3, B: 4}``. The graph therefore holds eight quanta: two independent
    three-quantum chains plus two isolated U quanta.
    """

    def setUp(self):
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # The init connections in this file declare no dimensions, so fall
            # back to an empty tuple when the attribute is absent.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of quanta generated for it.
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Init-input refs are kept in their own variable so that each
            # quantum's ``initInputs`` holds only this task's init-inputs and
            # never its init-outputs.
            initInputRefs = None
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                initInputRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = initInputRefs
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputs[taskDef] = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]

            inputDSType = _makeDatasetType(connections.input)
            outputDSType = _makeDatasetType(connections.output)
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # One data ID shared by the quantum and its input/output refs.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(
            quantumMap, metadata=METADATA, universe=universe, initInputs=initInputs, initOutputs=initOutputs
        )
        self.universe = universe

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        # R heads each chain and U is isolated, so their quanta have no
        # upstream quanta.
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        # T ends each chain and U is isolated, so their quanta have no
        # downstream quanta.
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        # Two data IDs per task.
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # The expected set excludes InitOutput connections. Init-outputs that
        # another task consumes still appear through that task's InitInput
        # connection.
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions (plus the isolated U quanta).
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own.
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Summed over all subsets, the membership count must equal the
        # number of quanta in the full graph.
        count = sum(
            cg.checkQuantumInGraph(node.quantum) for node in self.qGraph for cg in connectedGraphs
        )
        self.assertEqual(len(self.qGraph), count)

        # Check each subset individually. Keying a dict on task count would
        # silently keep only one subset per size.
        for cg in connectedGraphs:
            tasksInGraph = set(cg.taskGraph)
            if len(tasksInGraph) == 1:
                self.assertEqual({self.tasks[-1]}, tasksInGraph)
                continue
            self.assertEqual(set(self.tasks[:-1]), tasksInGraph)
            # Assumes iteration yields nodes in dependency order (R, S, T).
            # The first assertion below already relied on that ordering.
            allNodes = list(cg)
            self.assertEqual({allNodes[0]}, cg.determineInputsToQuantumNode(allNodes[1]))
            self.assertEqual({allNodes[1]}, cg.determineInputsToQuantumNode(allNodes[2]))

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # Different universes: a version bump only logs...
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # ...while a different namespace is rejected.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)

                nodeIds = [n.nodeId for n in self.qGraph]
                # Two distinct node positions, chosen at random.
                nodeNumberId, nodeNumberId2 = random.sample(range(len(nodeIds)), 2)
                nodeNumber = nodeIds[nodeNumberId]
                nodeNumber2 = nodeIds[nodeNumberId2]

                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))

                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri)
        self.assertEqual(self.qGraph, restore)
        nodeId = list(self.qGraph)[0].nodeId
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeId,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit every check from ``lsst.utils.tests.MemoryTestCase`` unchanged."""
def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` naming convention).

    Calls ``lsst.utils.tests.init()`` before the tests in this module run.
    The ``module`` argument is accepted but not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running the file directly: set up the LSST test helpers, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()