Coverage for tests/test_graphBuilder.py: 20%
65 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-23 10:43 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-23 10:43 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests of things related to the GraphBuilder class."""
30import io
31import logging
32import unittest
34import lsst.utils.tests
35from lsst.daf.butler.registry import UserExpressionError
36from lsst.pipe.base import QuantumGraph
37from lsst.pipe.base.graphBuilder import DatasetQueryConstraintVariant
38from lsst.pipe.base.tests import simpleQGraph
39from lsst.utils.tests import temporaryDirectory
# Logger named after this module.
_LOG = logging.getLogger(name=__name__)
class GraphBuilderTestCase(unittest.TestCase):
    """Test graph building."""

    def _assertGraph(self, graph: QuantumGraph) -> None:
        """Check basic structure of a graph made by ``makeSimpleQGraph``.

        Verifies that every task has exactly two init-output refs, that the
        graph has exactly one input quantum, and that there is exactly one
        global init-output ref.

        Parameters
        ----------
        graph : `QuantumGraph`
            Graph to check.
        """
        for taskDef in graph.iterTaskGraph():
            refs = graph.initOutputRefs(taskDef)
            # Task has one initOutput; the second ref is for the config dataset.
            self.assertEqual(len(refs), 2)

        self.assertEqual(len(list(graph.inputQuanta)), 1)

        # This includes only the "packages" dataset for now.
        refs = graph.globalInitOutputRefs()
        self.assertEqual(len(refs), 1)

    def testDefault(self):
        """Simple test to verify makeSimpleQGraph can be used to make a Quantum
        Graph, with and without a dataset query constraint.
        """
        with temporaryDirectory() as root:
            # makeSimpleQGraph calls GraphBuilder.
            butler, qgraph = simpleQGraph.makeSimpleQGraph(root=root)
            # By default makeSimpleQGraph makes a graph with 5 nodes.
            self.assertEqual(len(qgraph), 5)
            self._assertGraph(qgraph)

            # Rebuild against the same butler with dataset query constraints
            # disabled; the butler is already populated, so skip that step.
            constraint = DatasetQueryConstraintVariant.OFF
            _, qgraph2 = simpleQGraph.makeSimpleQGraph(
                butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
            )
            # When all outputs are random resolved refs, direct comparison
            # of graphs does not work because IDs are different. Can only
            # verify the number of quanta in the graph without doing something
            # terribly complicated.
            self.assertEqual(len(qgraph2), 5)

            # Rebuild again, constraining the query on a single dataset type.
            constraint = DatasetQueryConstraintVariant.fromExpression("add_dataset0")
            _, qgraph3 = simpleQGraph.makeSimpleQGraph(
                butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
            )
            self.assertEqual(len(qgraph3), 5)

    def testAddInstrumentMismatch(self):
        """Verify that a UserExpressionError is raised if the instrument in
        the user query does not match the instrument in the pipeline.
        """
        with temporaryDirectory() as root:
            pipeline = simpleQGraph.makeSimplePipeline(
                nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
            )
            with self.assertRaises(UserExpressionError):
                simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, userQuery="instrument = 'foo'")

    def testUserQueryBind(self):
        """Verify that bind values work for user query."""
        pipeline = simpleQGraph.makeSimplePipeline(
            nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
        )
        instr = simpleQGraph.SimpleInstrument.getName()
        # With a literal in the user query.
        with temporaryDirectory() as root:
            simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, userQuery=f"instrument = '{instr}'")
        # With a bind for the user query.
        with temporaryDirectory() as root:
            simpleQGraph.makeSimpleQGraph(
                root=root, pipeline=pipeline, userQuery="instrument = instr", bind={"instr": instr}
            )

    def test_datastore_records(self):
        """Test for generating datastore records, and that they survive a
        save/load round trip of the graph.
        """
        with temporaryDirectory() as root:
            # A FileDatastore (not in-memory) is needed for this test so that
            # file datastore records exist.
            butler, qgraph1 = simpleQGraph.makeSimpleQGraph(
                root=root, inMemory=False, makeDatastoreRecords=True
            )

            # Save and reload through an in-memory buffer.
            with io.BytesIO() as buffer:
                qgraph1.save(buffer)
                buffer.seek(0)
                qgraph2 = QuantumGraph.load(buffer, universe=butler.dimensions)

            for qgraph in (qgraph1, qgraph2):
                self.assertEqual(len(qgraph), 5)
                for i, qnode in enumerate(qgraph):
                    quantum = qnode.quantum
                    self.assertIsNotNone(quantum.datastore_records)
                    if i != 0:
                        # Quanta other than the first have no pre-existing
                        # inputs, hence no datastore records.
                        self.assertEqual(quantum.datastore_records, {})
                        continue

                    # Only the first quantum has a pre-existing input.
                    datastore_name = "FileDatastore@<butlerRoot>"
                    self.assertEqual(set(quantum.datastore_records), {datastore_name})
                    records_data = quantum.datastore_records[datastore_name]
                    records_by_id = dict(records_data.records)
                    self.assertEqual(len(records_by_id), 1)
                    (records_by_table,) = records_by_id.values()
                    file_records = records_by_table["file_datastore_records"]
                    self.assertEqual(
                        [record.path for record in file_records],
                        ["test/add_dataset0/add_dataset0_INSTR_det0_test.pickle"],
                    )
if __name__ == "__main__":
    # Set up the LSST test utilities first, then let unittest load and run
    # the test cases defined in this module.
    lsst.utils.tests.init()
    unittest.main(module="__main__")