Coverage for tests/test_graphBuilder.py: 20%
65 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of things related to the GraphBuilder class."""
24import io
25import logging
26import unittest
28import lsst.utils.tests
29from lsst.daf.butler.registry import UserExpressionError
30from lsst.pipe.base import QuantumGraph
31from lsst.pipe.base.graphBuilder import DatasetQueryConstraintVariant
32from lsst.pipe.base.tests import simpleQGraph
33from lsst.utils.tests import temporaryDirectory
# Module-level logger named after this module (not referenced by the code shown here).
_LOG = logging.getLogger(name=__name__)
class GraphBuilderTestCase(unittest.TestCase):
    """Test graph building."""

    def _assertGraph(self, graph: QuantumGraph) -> None:
        """Check basic structure of the graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            Graph produced by ``simpleQGraph.makeSimpleQGraph``.
        """
        for taskDef in graph.iterTaskGraph():
            refs = graph.initOutputRefs(taskDef)
            # Each task has one initOutput; the second ref is for the config
            # dataset.
            self.assertEqual(len(refs), 2)

        self.assertEqual(len(list(graph.inputQuanta)), 1)

        # This includes only the "packages" dataset for now.
        refs = graph.globalInitOutputRefs()
        self.assertEqual(len(refs), 1)

    def testDefault(self):
        """Verify that makeSimpleQGraph can be used to make a QuantumGraph,
        with and without dataset query constraints.
        """
        with temporaryDirectory() as root:
            # makeSimpleQGraph calls GraphBuilder.
            butler, qgraph = simpleQGraph.makeSimpleQGraph(root=root)
            # By default makeSimpleQGraph makes a graph with 5 nodes.
            self.assertEqual(len(qgraph), 5)
            self._assertGraph(qgraph)

            constraint = DatasetQueryConstraintVariant.OFF
            _, qgraph2 = simpleQGraph.makeSimpleQGraph(
                butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
            )
            # When all outputs are random resolved refs, direct comparison
            # of graphs does not work because IDs are different. We can only
            # verify the number of quanta in the graph without doing something
            # terribly complicated.
            self.assertEqual(len(qgraph2), 5)

            constraint = DatasetQueryConstraintVariant.fromExpression("add_dataset0")
            _, qgraph3 = simpleQGraph.makeSimpleQGraph(
                butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
            )
            self.assertEqual(len(qgraph3), 5)

    def testAddInstrumentMismatch(self):
        """Verify that a `UserExpressionError` is raised if the instrument in
        the user query does not match the instrument in the pipeline.
        """
        with temporaryDirectory() as root:
            pipeline = simpleQGraph.makeSimplePipeline(
                nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
            )
            with self.assertRaises(UserExpressionError):
                simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, userQuery="instrument = 'foo'")

    def testUserQueryBind(self):
        """Verify that bind values work for the user query.

        The same query is expressed once with a literal and once with a bind
        parameter; both must produce a graph with the expected number of
        quanta.
        """
        pipeline = simpleQGraph.makeSimplePipeline(
            nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
        )
        instr = simpleQGraph.SimpleInstrument.getName()
        # With a literal in the user query.
        with temporaryDirectory() as root:
            _, qgraph = simpleQGraph.makeSimpleQGraph(
                root=root, pipeline=pipeline, userQuery=f"instrument = '{instr}'"
            )
            self.assertEqual(len(qgraph), 5)
        # With a bind for the user query.
        with temporaryDirectory() as root:
            _, qgraph = simpleQGraph.makeSimpleQGraph(
                root=root, pipeline=pipeline, userQuery="instrument = instr", bind={"instr": instr}
            )
            self.assertEqual(len(qgraph), 5)

    def test_datastore_records(self):
        """Test generation of datastore records, before and after a
        save/load round trip of the graph.
        """
        with temporaryDirectory() as root:
            # A FileDatastore (inMemory=False) is needed for this test.
            butler, qgraph1 = simpleQGraph.makeSimpleQGraph(
                root=root, inMemory=False, makeDatastoreRecords=True
            )

            # Save and reload so that records are checked in both the
            # original and the deserialized graph.
            buffer = io.BytesIO()
            qgraph1.save(buffer)
            buffer.seek(0)
            qgraph2 = QuantumGraph.load(buffer, universe=butler.dimensions)

            for label, qgraph in (("original", qgraph1), ("reloaded", qgraph2)):
                with self.subTest(graph=label):
                    self.assertEqual(len(qgraph), 5)
                    for i, qnode in enumerate(qgraph):
                        quantum = qnode.quantum
                        self.assertIsNotNone(quantum.datastore_records)
                        # Only the first quantum has a pre-existing input.
                        # NOTE(review): this relies on the iteration order of
                        # QuantumGraph putting that quantum first.
                        if i == 0:
                            datastore_name = "FileDatastore@<butlerRoot>"
                            self.assertEqual(set(quantum.datastore_records.keys()), {datastore_name})
                            records_data = quantum.datastore_records[datastore_name]
                            record_entries = dict(records_data.records)
                            self.assertEqual(len(record_entries), 1)
                            _, records_by_table = record_entries.popitem()
                            file_records = records_by_table["file_datastore_records"]
                            self.assertEqual(
                                [record.path for record in file_records],
                                ["test/add_dataset0/add_dataset0_INSTR_det0_test.pickle"],
                            )
                        else:
                            self.assertEqual(quantum.datastore_records, {})
if __name__ == "__main__":
    # Initialise lsst.utils test support first, then let unittest run the
    # test cases defined in this module.
    lsst.utils.tests.init()
    unittest.main()