Coverage for tests/test_quantumGraph.py: 21%
300 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-02 14:36 +0000
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import QuantumNode
44from lsst.utils.introspection import get_full_type_name
# boto3/moto are optional test dependencies.  When either import fails,
# ``boto3`` is set to None and ``mock_s3`` becomes a pass-through decorator so
# that code decorated with it can still be defined.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    boto3 = None

    def mock_s3(cls):
        """A no-op decorator in case moto mock_s3 can not be imported."""
        return cls


# Metadata attached to the test graph in setUp and compared after a
# saveUri/loadUri round trip.
METADATA = {"a": [1, 2, 3]}
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` (both with dimensions
    A and B), and declares the init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Input", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")
class Dummy1PipelineTask(PipelineTask):
    """Dummy task whose only customisation is its config class."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Its init-input (``Dummy1InitOutput``) and input (``Dummy1Output``) use
    the dataset type names that `Dummy1Connections` declares as init-output
    and output.  It writes ``Dummy2Output`` and declares the init-output
    ``Dummy2InitOutput``.
    """

    initInput = cT.InitInput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")
class Dummy2PipelineTask(PipelineTask):
    """Dummy task whose only customisation is its config class."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Its init-input (``Dummy2InitOutput``) and input (``Dummy2Output``) use
    the dataset type names that `Dummy2Connections` declares as init-output
    and output.  It writes ``Dummy3Output`` and declares the init-output
    ``Dummy3InitOutput``.
    """

    initInput = cT.InitInput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy3InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy3Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")
class Dummy3PipelineTask(PipelineTask):
    """Dummy task whose only customisation is its config class."""

    ConfigClass = Dummy3Config
# Test if a Task that does not interact with the other Tasks works fine in
# the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the fourth dummy task.

    Reads ``Dummy4Input`` and writes ``Dummy4Output``; neither dataset type
    name is used by the other dummy connections classes, and there are no
    init-input or init-output connections.
    """

    input = cT.Input(name="Dummy4Input", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
    output = cT.Output(name="Dummy4Output", storageClass="ExposureF", doc="n/a", dimensions=("A", "B"))
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for `Dummy4PipelineTask`, bound to `Dummy4Connections`."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")
class Dummy4PipelineTask(PipelineTask):
    """Dummy task whose only customisation is its config class."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds one graph from the four dummy tasks, labelled "R", "S",
    "T" and "U", with two quanta per task (data IDs ``{A: 1, B: 2}`` and
    ``{A: 3, B: 4}``).  R, S and T are linked through shared dataset type
    names; U shares none with the others.
    """

    def setUp(self):
        # Minimal dimension universe containing only the dimensions "A" and
        # "B" used by the dummy connections.
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Init connections carry no ``dimensions`` attribute, hence the
            # empty-tuple fallback.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            # Keep init-input and init-output refs in separate variables: a
            # single shared name let the init-output refs overwrite the
            # init-input refs, so each Quantum below was handed init-*output*
            # refs as its ``initInputs``.
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                initInputRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = initInputRefs
            else:
                initInputRefs = None
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputRefs = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]
                initOutputs[taskDef] = initOutputRefs
            inputDSType = _makeDatasetType(connections.input)
            outputDSType = _makeDatasetType(connections.output)
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                inputRefs = [DatasetRef(inputDSType, dataId)]
                outputRefs = [DatasetRef(outputDSType, dataId)]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        globalInitOutputs = [DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe))]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
        )
        self.universe = universe

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A freshly generated UUID is not a node id of this graph.
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        # Tasks R (index 0) and U (index 3) are the ones expected here.
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self):
        # Tasks T (index 2) and U (index 3) are the ones expected here.
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        # setUp creates two quanta per task.
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # Expected names: every connection that is not declared as an
        # InitOutput.  A name also declared as another task's InitInput
        # (e.g. "Dummy1InitOutput") is therefore still expected.
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)

    def testSubsetToConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Total number of (quantum, subgraph) memberships must equal the
        # number of quanta in the full graph.
        count = sum(
            1
            for node in self.qGraph
            for cg in connectedGraphs
            if cg.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

        # Check the task content of every subgraph (keying a dict on the
        # task-graph size would keep only one subgraph per size).
        for cg in connectedGraphs:
            taskSet = set(cg.taskGraph)
            if len(taskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), taskSet)
            elif len(taskSet) == 1:
                self.assertEqual({self.tasks[-1]}, taskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            # Relies on the subgraph's iteration order: the first node is
            # expected to be the sole input of the second.
            allNodes = list(cg)
            inputNodes = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, inputNodes)

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        # The tested nodes themselves are expected in the result as well.
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = next(iter(self.qGraph)).nodeId
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # Different universes.
            # Same namespace, different version: loading succeeds but is
            # expected to log on "lsst.daf.butler" at INFO or above.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            # Different namespace: loading is expected to fail.
            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        uri = None
        try:
            # Only the file name is needed; close the handle before the graph
            # is written to that path.  delete=False keeps the file around
            # until the ``finally`` clause removes it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri)
            self.assertEqual(restore.metadata, METADATA)
            self.assertEqual(self.qGraph, restore)
            # Pick two distinct nodes at random.
            nodeIds = [n.nodeId for n in self.qGraph]
            nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
            restoreSub = QuantumGraph.loadUri(
                uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
            )
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(nodeNumber),
                    restore.getQuantumNodeByNodeId(nodeNumber2),
                },
            )
            # verify an error when requesting a nonexistent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            if uri is not None:
                os.remove(uri)

        # A target whose extension is not ".qgraph" is expected to be
        # rejected.
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`
    unchanged.
    """
def setup_module(module):
    """Module-level setup hook: initialise `lsst.utils.tests`.

    Parameters
    ----------
    module
        The test module being set up; unused here.
    """
    lsst.utils.tests.init()
# Script entry point: same initialisation as setup_module, then run the tests.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()