Coverage for tests/test_graphBuilder.py: 14%
94 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-19 02:08 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-19 02:08 -0800
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of things related to the GraphBuilder class."""
24import io
25import logging
26import unittest
28import lsst.utils.tests
29from lsst.daf.butler.registry import UserExpressionError
30from lsst.pipe.base import QuantumGraph
31from lsst.pipe.base.graphBuilder import DatasetQueryConstraintVariant
32from lsst.pipe.base.tests import simpleQGraph
33from lsst.utils.tests import temporaryDirectory
# Module-level logger named after this module; none of the tests below
# currently write to it.
_LOG = logging.getLogger(name=__name__)
class GraphBuilderTestCase(unittest.TestCase):
    """Tests of GraphBuilder, driven through the ``simpleQGraph`` test
    helpers (``makeSimpleQGraph`` calls GraphBuilder internally)."""

    def _assertGraph(self, graph: QuantumGraph) -> None:
        """Check basic structure of the graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            Graph produced by ``simpleQGraph.makeSimpleQGraph``.
        """
        for taskDef in graph.iterTaskGraph():
            task_init_refs = graph.initOutputRefs(taskDef)
            # task has one initOutput, second ref is for config dataset
            self.assertEqual(len(task_init_refs), 2)

        self.assertEqual(len(list(graph.inputQuanta)), 1)

        # This includes only "packages" dataset for now.
        global_init_refs = graph.globalInitOutputRefs()
        self.assertEqual(len(global_init_refs), 1)

    def testDefault(self):
        """Simple test to verify makeSimpleQGraph can be used to make a Quantum
        Graph, and that the dataset query constraints used here (OFF, and a
        constraint on "add_dataset0") yield a graph equal to the default one.
        """
        with temporaryDirectory() as root:
            # makeSimpleQGraph calls GraphBuilder.
            butler, qgraph = simpleQGraph.makeSimpleQGraph(root=root)
            # by default makeSimpleQGraph makes a graph with 5 nodes
            self.assertEqual(len(qgraph), 5)
            self._assertGraph(qgraph)
            constraint = DatasetQueryConstraintVariant.OFF
            _, qgraph2 = simpleQGraph.makeSimpleQGraph(
                butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
            )
            self.assertEqual(len(qgraph2), 5)
            self.assertEqual(qgraph, qgraph2)
            constraint = DatasetQueryConstraintVariant.fromExpression("add_dataset0")
            _, qgraph3 = simpleQGraph.makeSimpleQGraph(
                butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
            )
            self.assertEqual(qgraph2, qgraph3)

    def testAddInstrumentMismatch(self):
        """Verify that a UserExpressionError is raised if the instrument in
        the user query does not match the instrument in the pipeline."""
        with temporaryDirectory() as root:
            pipeline = simpleQGraph.makeSimplePipeline(
                nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
            )
            with self.assertRaises(UserExpressionError):
                simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, userQuery="instrument = 'foo'")

    def testUserQueryBind(self):
        """Verify that bind values work for user query.

        The same query is issued once with the instrument name written as a
        literal and once through a bind parameter; both must build the full
        graph.
        """
        pipeline = simpleQGraph.makeSimplePipeline(
            nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
        )
        instr = simpleQGraph.SimpleInstrument.getName()
        # With a literal in the user query
        with temporaryDirectory() as root:
            _, qgraph = simpleQGraph.makeSimpleQGraph(
                root=root, pipeline=pipeline, userQuery=f"instrument = '{instr}'"
            )
            self.assertEqual(len(qgraph), 5)
        # With a bind for the user query
        with temporaryDirectory() as root:
            _, qgraph = simpleQGraph.makeSimpleQGraph(
                root=root, pipeline=pipeline, userQuery="instrument = instr", bind={"instr": instr}
            )
            self.assertEqual(len(qgraph), 5)

    def test_datastore_records(self):
        """Test for generating datastore records, both on a freshly built
        graph and on the same graph after a save/load round trip."""
        with temporaryDirectory() as root:
            # need FileDatastore for this tests
            butler, qgraph1 = simpleQGraph.makeSimpleQGraph(
                root=root, inMemory=False, makeDatastoreRecords=True
            )

            # save and reload
            buffer = io.BytesIO()
            qgraph1.save(buffer)
            buffer.seek(0)
            qgraph2 = QuantumGraph.load(buffer, universe=butler.dimensions)
            del buffer

            for qgraph in (qgraph1, qgraph2):
                self.assertEqual(len(qgraph), 5)
                for i, qnode in enumerate(qgraph):
                    quantum = qnode.quantum
                    self.assertIsNotNone(quantum.datastore_records)
                    # only the first quantum has a pre-existing input
                    if i == 0:
                        datastore_name = "FileDatastore@<butlerRoot>"
                        self.assertEqual(set(quantum.datastore_records.keys()), {datastore_name})
                        records_data = quantum.datastore_records[datastore_name]
                        # Distinct names for each level instead of rebinding
                        # one variable to three different types.
                        record_map = dict(records_data.records)
                        self.assertEqual(len(record_map), 1)
                        _, entry = record_map.popitem()
                        file_records = entry["file_datastore_records"]
                        self.assertEqual(
                            [record.path for record in file_records],
                            ["test/add_dataset0/add_dataset0_INSTR_det0_test.pickle"],
                        )
                    else:
                        self.assertEqual(quantum.datastore_records, {})

    def testResolveRefs(self):
        """Test GraphBuilder with resolveRefs=False and resolveRefs=True.

        Refs are expected to be resolved only when ``resolveRefs=True``,
        except refs to the pre-existing "add_dataset0" inputs, which are
        always resolved.
        """

        def _assert_resolved(refs):
            # List the offenders so a failure shows which refs are wrong,
            # rather than a bare "False is not true".
            unresolved = [ref for ref in refs if ref.id is None]
            self.assertEqual(unresolved, [], "expected every ref to be resolved")

        def _assert_unresolved(refs):
            resolved = [ref for ref in refs if ref.id is not None]
            self.assertEqual(resolved, [], "expected every ref to be unresolved")

        for resolveRefs in (False, True):
            with self.subTest(resolveRefs=resolveRefs):
                assert_refs = _assert_resolved if resolveRefs else _assert_unresolved

                with temporaryDirectory() as root:
                    _, qgraph = simpleQGraph.makeSimpleQGraph(root=root, resolveRefs=resolveRefs)
                    self.assertEqual(len(qgraph), 5)

                    # check per-quantum inputs/outputs
                    for node in qgraph:
                        quantum = node.quantum
                        for datasetType, refs in quantum.inputs.items():
                            if datasetType.name == "add_dataset0":
                                # Existing refs are always resolved
                                _assert_resolved(refs)
                            else:
                                assert_refs(refs)
                        for refs in quantum.outputs.values():
                            assert_refs(refs)

                    # check per-task init-inputs/init-outputs
                    for taskDef in qgraph.iterTaskGraph():
                        if (refs := qgraph.initInputRefs(taskDef)) is not None:
                            assert_refs(refs)
                        if (refs := qgraph.initOutputRefs(taskDef)) is not None:
                            assert_refs(refs)

                    # check global init-outputs
                    assert_refs(qgraph.globalInitOutputRefs())
if __name__ == "__main__":
    # Allow running this file directly as a script: call
    # lsst.utils.tests.init() first, then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()