Coverage for tests/test_graphBuilder.py: 14%

94 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-05 02:05 -0800

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Tests of things related to the GraphBuilder class.""" 

23 

24import io 

25import logging 

26import unittest 

27 

28import lsst.utils.tests 

29from lsst.daf.butler.registry import UserExpressionError 

30from lsst.pipe.base import QuantumGraph 

31from lsst.pipe.base.graphBuilder import DatasetQueryConstraintVariant 

32from lsst.pipe.base.tests import simpleQGraph 

33from lsst.utils.tests import temporaryDirectory 

34 

35_LOG = logging.getLogger(__name__) 

36 

37 

class GraphBuilderTestCase(unittest.TestCase):
    """Tests of quantum graph construction through
    `simpleQGraph.makeSimpleQGraph`.
    """

    def _assertGraph(self, graph: QuantumGraph) -> None:
        """Check basic structure of the graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            Graph whose init-output refs and input quanta are checked.
        """
        for taskDef in graph.iterTaskGraph():
            refs = graph.initOutputRefs(taskDef)
            # task has one initOutput, second ref is for config dataset
            self.assertEqual(len(refs), 2)

        self.assertEqual(len(list(graph.inputQuanta)), 1)

        # This includes only "packages" dataset for now.
        refs = graph.globalInitOutputRefs()
        self.assertEqual(len(refs), 1)

    def testDefault(self) -> None:
        """Simple test to verify makeSimpleQGraph can be used to make a
        Quantum Graph.
        """
        with temporaryDirectory() as root:
            # makeSimpleQGraph calls GraphBuilder.
            butler, qgraph = simpleQGraph.makeSimpleQGraph(root=root)
            # by default makeSimpleQGraph makes a graph with 5 nodes
            self.assertEqual(len(qgraph), 5)
            self._assertGraph(qgraph)

            # Rebuild against the same butler with the dataset query
            # constraint switched off; the graph is expected to be equal.
            constraint = DatasetQueryConstraintVariant.OFF
            _, qgraph2 = simpleQGraph.makeSimpleQGraph(
                butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
            )
            self.assertEqual(len(qgraph2), 5)
            self.assertEqual(qgraph, qgraph2)

            # Same again with a constraint built from a dataset type name.
            constraint = DatasetQueryConstraintVariant.fromExpression("add_dataset0")
            _, qgraph3 = simpleQGraph.makeSimpleQGraph(
                butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
            )
            self.assertEqual(qgraph2, qgraph3)

    def testAddInstrumentMismatch(self) -> None:
        """Verify that a UserExpressionError is raised if the instrument in
        the user query does not match the instrument in the pipeline.
        """
        with temporaryDirectory() as root:
            pipeline = simpleQGraph.makeSimplePipeline(
                nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
            )
            with self.assertRaises(UserExpressionError):
                simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, userQuery="instrument = 'foo'")

    def testUserQueryBind(self) -> None:
        """Verify that bind values work for user query.

        There are no explicit assertions; the test passes if both graph
        builds complete without raising.
        """
        pipeline = simpleQGraph.makeSimplePipeline(
            nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
        )
        instr = simpleQGraph.SimpleInstrument.getName()
        # With a literal in the user query
        with temporaryDirectory() as root:
            simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, userQuery=f"instrument = '{instr}'")
        # With a bind for the user query
        with temporaryDirectory() as root:
            simpleQGraph.makeSimpleQGraph(
                root=root, pipeline=pipeline, userQuery="instrument = instr", bind={"instr": instr}
            )

    def test_datastore_records(self) -> None:
        """Test for generating datastore records."""
        with temporaryDirectory() as root:
            # need FileDatastore for this test
            butler, qgraph1 = simpleQGraph.makeSimpleQGraph(
                root=root, inMemory=False, makeDatastoreRecords=True
            )

            # save and reload
            buffer = io.BytesIO()
            qgraph1.save(buffer)
            buffer.seek(0)
            qgraph2 = QuantumGraph.load(buffer, universe=butler.dimensions)
            del buffer

            # The same checks must hold for the freshly built graph and for
            # the copy that went through save/load.
            for qgraph in (qgraph1, qgraph2):
                self.assertEqual(len(qgraph), 5)
                for i, qnode in enumerate(qgraph):
                    quantum = qnode.quantum
                    self.assertIsNotNone(quantum.datastore_records)
                    # only the first quantum has a pre-existing input
                    if i == 0:
                        datastore_name = "FileDatastore@<butlerRoot>"
                        self.assertEqual(set(quantum.datastore_records.keys()), {datastore_name})
                        records_data = quantum.datastore_records[datastore_name]
                        all_records = dict(records_data.records)
                        self.assertEqual(len(all_records), 1)
                        # Exactly one entry is present; take its value.
                        _, table_records = all_records.popitem()
                        file_records = table_records["file_datastore_records"]
                        self.assertEqual(
                            [record.path for record in file_records],
                            ["test/add_dataset0/add_dataset0_INSTR_det0_test.pickle"],
                        )
                    else:
                        self.assertEqual(quantum.datastore_records, {})

    def testResolveRefs(self) -> None:
        """Test GraphBuilder with both resolveRefs=False and
        resolveRefs=True.
        """

        def _assert_resolved(refs):
            self.assertTrue(
                all(ref.id is not None for ref in refs), "expected every ref to have an id"
            )

        def _assert_unresolved(refs):
            self.assertTrue(all(ref.id is None for ref in refs), "expected no ref to have an id")

        for resolveRefs in (False, True):
            with self.subTest(resolveRefs=resolveRefs):
                assert_refs = _assert_resolved if resolveRefs else _assert_unresolved

                with temporaryDirectory() as root:
                    _, qgraph = simpleQGraph.makeSimpleQGraph(root=root, resolveRefs=resolveRefs)
                    self.assertEqual(len(qgraph), 5)

                    # check per-quantum inputs/outputs
                    for node in qgraph:
                        quantum = node.quantum
                        for datasetType, refs in quantum.inputs.items():
                            if datasetType.name == "add_dataset0":
                                # Existing refs are always resolved
                                _assert_resolved(refs)
                            else:
                                assert_refs(refs)
                        # Only the refs matter for outputs, not their
                        # dataset types.
                        for refs in quantum.outputs.values():
                            assert_refs(refs)

                    # check per-task init-inputs/init-outputs
                    for taskDef in qgraph.iterTaskGraph():
                        if (refs := qgraph.initInputRefs(taskDef)) is not None:
                            assert_refs(refs)
                        if (refs := qgraph.initOutputRefs(taskDef)) is not None:
                            assert_refs(refs)

                    # check global init-outputs
                    assert_refs(qgraph.globalInitOutputRefs())

172 

173 

if __name__ == "__main__":
    # Call the lsst.utils.tests initializer before handing control to
    # unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()