Coverage for tests/test_quantumGraph.py: 20%
312 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-04-13 10:09 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23import pickle
24import random
25import tempfile
26import unittest
27import uuid
28from itertools import chain
29from typing import Iterable
31import lsst.pipe.base.connectionTypes as cT
32import lsst.utils.tests
33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
34from lsst.pex.config import Field
35from lsst.pipe.base import (
36 DatasetTypeName,
37 PipelineTask,
38 PipelineTaskConfig,
39 PipelineTaskConnections,
40 QuantumGraph,
41 TaskDef,
42)
43from lsst.pipe.base.graph.quantumNode import QuantumNode
44from lsst.utils.introspection import get_full_type_name
# Metadata attached to the graph built in ``setUp``; ``testSaveLoadUri``
# checks that it survives a save/load round trip.
METADATA = dict(a=[1, 2, 3])
class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first dummy task.

    Declares one init-output, one input and one output, all with the
    ``ExposureF`` storage class; the input and output use dimensions
    ``("A", "B")``.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    # Single placeholder integer field, defaulting to 1.
    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy1PipelineTask(PipelineTask):
    """Task class with no behavior of its own; it only selects `Dummy1Config`."""

    ConfigClass = Dummy1Config
class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second dummy task.

    Its init-input is named ``Dummy1InitOutput`` and its input is named
    ``Dummy1Output``, i.e. the names declared as outputs by
    `Dummy1Connections`.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    # Single placeholder integer field, defaulting to 1.
    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy2PipelineTask(PipelineTask):
    """Task class with no behavior of its own; it only selects `Dummy2Config`."""

    ConfigClass = Dummy2Config
class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third dummy task.

    Its init-input is named ``Dummy2InitOutput`` and its input is named
    ``Dummy2Output``, i.e. the names declared as outputs by
    `Dummy2Connections`.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    # Single placeholder integer field, defaulting to 1.
    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy3PipelineTask(PipelineTask):
    """Task class with no behavior of its own; it only selects `Dummy3Config`."""

    ConfigClass = Dummy3Config
# Test if a Task that does not interact with the other Tasks works fine in
# the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the stand-alone fourth dummy task.

    Neither ``Dummy4Input`` nor ``Dummy4Output`` is referenced by any of the
    other dummy connection classes, and no init connections are declared.
    """

    input = cT.Input(
        name="Dummy4Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy4Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for `Dummy4PipelineTask`, bound to `Dummy4Connections`."""

    # Single placeholder integer field, defaulting to 1.
    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )
class Dummy4PipelineTask(PipelineTask):
    """Task class with no behavior of its own; it only selects `Dummy4Config`."""

    ConfigClass = Dummy4Config
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from four tasks labelled "R", "S", "T" and "U"
    (Dummy1..Dummy4), each with two quanta at data IDs ``A=1, B=2`` and
    ``A=3, B=4``. Dummy1 -> Dummy2 -> Dummy3 are chained through their
    input/output dataset type names; Dummy4 shares no dataset type with the
    others.
    """

    def setUp(self):
        """Create the dimension universe, the quanta and the graph under test."""
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            # Init connections carry no ``dimensions`` attribute, hence the
            # empty-tuple fallback.
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        dataset_types = set()
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = initRefs
                dataset_types.add(initInputDSType)
            else:
                initRefs = None
            if connections.initOutputs:
                # NOTE(review): this rebinds ``initRefs``, so for any task
                # with an init-output the Quantum below is given the
                # init-OUTPUT refs as its ``initInputs``. That looks
                # accidental, but the expected counts in ``testSubset`` may
                # depend on it, so the behavior is preserved here -- confirm
                # before changing.
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initRefs = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]
                initOutputs[taskDef] = initRefs
                dataset_types.add(initOutputDSType)
            inputDSType = _makeDatasetType(connections.input)
            dataset_types.add(inputDSType)
            outputDSType = _makeDatasetType(connections.output)
            dataset_types.add(outputDSType)
            for a, b in ((1, 2), (3, 4)):
                # One data ID is shared by the quantum and its input/output.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                inputRefs = [DatasetRef(inputDSType, dataId)]
                outputRefs = [DatasetRef(outputDSType, dataId)]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        dataset_types.add(self.packagesDSType)
        globalInitOutputs = [DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe))]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=dataset_types,
        )
        self.universe = universe
        self.num_dataset_types = len(dataset_types)

    def testTaskGraph(self):
        """Every TaskDef used to build the graph appears in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every input quantum appears in the quantum-level graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Lookup by node id returns the node; an unknown id raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip yields an equal graph."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """Input quanta are those of the first and the stand-alone task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self):
        """Output quanta are those of the third and the stand-alone task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph holds two quanta per task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` equals every connection name that is not an
        `InitOutput` connection.
        """
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections)
            for connection in conClass.allConnections.values()  # type: ignore
            if not isinstance(connection, cT.InitOutput)
        }
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A single-node subset keeps the quantum, build id and global
        init-outputs of the parent graph.
        """
        allNodes = list(self.qGraph)
        firstNode = allNodes[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(firstNode.quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)
        # Depending on which task was first the list can contain different
        # number of datasets. The first task can be either Dummy1 or Dummy4.
        num_types = {"R": 4, "U": 3}
        self.assertEqual(len(subset.registryDatasetTypes()), num_types[firstNode.taskDef.label])

    def testSubsetToConnected(self):
        """Splitting yields four connected sub-graphs that partition the
        quanta of the full graph.
        """
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum must be found in exactly as many sub-graphs as needed
        # for the hits to total the size of the full graph.
        count = 0
        for node in self.qGraph:
            for cg in connectedGraphs:
                if cg.checkQuantumInGraph(node.quantum):
                    count += 1
        self.assertEqual(len(self.qGraph), count)

        # Check the task set of every sub-graph. (Keying a dict on the task
        # graph size would let same-sized sub-graphs overwrite each other and
        # leave some of them unchecked.)
        for cg in connectedGraphs:
            tskSet = set(cg.taskGraph)
            if len(tskSet) == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif len(tskSet) == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)

        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Save to a file object and reload: fully, one node only, and with
        a differing dimension universe.
        """
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            self.assertEqual(len(restoreSub.registryDatasetTypes()), self.num_dataset_types)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # Different universes.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        """Save to a URI and reload: fully, by one and two node ids, and
        with invalid node / graph ids.
        """
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random in one call instead of
                # re-drawing until the second differs from the first.
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # The file was created with delete=False, so remove it by hand
            # whether or not the assertions above passed.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testSaveLoadNoRegistryDatasetTypes(self):
        """Test for reading quantum that is missing registry dataset types.

        This test depends on internals of QuantumGraph implementation, in
        particular that empty list of registry dataset types is not stored,
        which makes save file identical to the "old" format.
        """
        # Reset the list, this is safe as QuantumGraph itself does not use it.
        self.qGraph._registryDatasetTypes = []
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            self.assertEqual(restore.registryDatasetTypes(), [])

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        """The save header records the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Subclass of `lsst.utils.tests.MemoryTestCase` that adds nothing of
    its own; all behavior is inherited.
    """
def setup_module(module):
    """Call ``lsst.utils.tests.init()``.

    Parameters
    ----------
    module
        Unused. Presumably supplied by the test runner's module-level setup
        hook -- confirm against the runner in use.
    """
    lsst.utils.tests.init()
# Script entry point: initialise the LSST test utilities, then hand over to
# unittest's command-line runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()