Coverage for tests/test_graphBuilder.py: 20%

65 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-02 03:31 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Tests of things related to the GraphBuilder class.""" 

29 

30import io 

31import logging 

32import unittest 

33 

34import lsst.utils.tests 

35from lsst.daf.butler.registry import UserExpressionError 

36from lsst.pipe.base import QuantumGraph 

37from lsst.pipe.base.all_dimensions_quantum_graph_builder import DatasetQueryConstraintVariant 

38from lsst.pipe.base.tests import simpleQGraph 

39from lsst.utils.tests import temporaryDirectory 

40 

# Module-level logger, named after this module via ``__name__``.
_LOG = logging.getLogger(__name__)

42 

43 

class GraphBuilderTestCase(unittest.TestCase):
    """Exercise quantum graph construction via the ``simpleQGraph`` helpers."""

    def _assertGraph(self, graph: QuantumGraph) -> None:
        """Verify the init-output and input-quanta counts of ``graph``.

        Parameters
        ----------
        graph : `QuantumGraph`
            Graph whose basic structure is checked.
        """
        for task_def in graph.iterTaskGraph():
            # Each task has one initOutput; the second ref is for the
            # config dataset.
            init_refs = graph.initOutputRefs(task_def)
            self.assertEqual(len(init_refs), 2)

        input_quanta = list(graph.inputQuanta)
        self.assertEqual(len(input_quanta), 1)

        # For now this includes only the "packages" dataset.
        global_refs = graph.globalInitOutputRefs()
        self.assertEqual(len(global_refs), 1)

    def testDefault(self):
        """Check that makeSimpleQGraph can build a Quantum Graph, both with
        default arguments and with explicit dataset query constraints.
        """
        with temporaryDirectory() as root:
            # makeSimpleQGraph calls GraphBuilder.
            butler, first_graph = simpleQGraph.makeSimpleQGraph(root=root)

            def rebuild(constraint):
                """Build another graph on the same butler without
                re-populating it, using the given query constraint.
                """
                _, rebuilt = simpleQGraph.makeSimpleQGraph(
                    butler=butler, datasetQueryConstraint=constraint, callPopulateButler=False
                )
                return rebuilt

            # By default makeSimpleQGraph makes a graph with 5 nodes.
            self.assertEqual(len(first_graph), 5)
            self._assertGraph(first_graph)

            # When all outputs are random resolved refs, direct comparison
            # of graphs does not work because IDs are different. Only the
            # number of quanta can be verified without doing something
            # terribly complicated.
            unconstrained_graph = rebuild(DatasetQueryConstraintVariant.OFF)
            self.assertEqual(len(unconstrained_graph), 5)

            constrained_graph = rebuild(DatasetQueryConstraintVariant.fromExpression("add_dataset0"))
            self.assertEqual(len(constrained_graph), 5)

    def testAddInstrumentMismatch(self):
        """Verify that a UserExpressionError is raised if the instrument in
        the user query does not match the instrument in the pipeline.
        """
        instrument_class = "lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
        mismatched_query = "instrument = 'foo'"
        with temporaryDirectory() as root:
            pipeline = simpleQGraph.makeSimplePipeline(nQuanta=5, instrument=instrument_class)
            with self.assertRaises(UserExpressionError):
                simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, userQuery=mismatched_query)

    def testUserQueryBind(self):
        """Verify that a user query works with a literal instrument name and
        with a bind value.
        """
        instrument_class = "lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
        pipeline = simpleQGraph.makeSimplePipeline(nQuanta=5, instrument=instrument_class)
        instr = simpleQGraph.SimpleInstrument.getName()

        # Instrument name embedded as a literal in the query text.
        literal_query = f"instrument = '{instr}'"
        with temporaryDirectory() as root:
            simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, userQuery=literal_query)

        # Instrument name supplied through a bind mapping instead.
        bind_values = {"instr": instr}
        with temporaryDirectory() as root:
            simpleQGraph.makeSimpleQGraph(
                root=root, pipeline=pipeline, userQuery="instrument = instr", bind=bind_values
            )

    def test_datastore_records(self):
        """Test that datastore records are generated and survive a
        save/load round trip of the graph.
        """
        datastore_name = "FileDatastore@<butlerRoot>"
        expected_paths = ["test/add_dataset0/add_dataset0_INSTR_det0_test.pickle"]

        with temporaryDirectory() as root:
            # A FileDatastore is needed for this test.
            butler, built_graph = simpleQGraph.makeSimpleQGraph(
                root=root, inMemory=False, makeDatastoreRecords=True
            )

            # Round-trip the graph through an in-memory buffer.
            buffer = io.BytesIO()
            built_graph.save(buffer)
            buffer.seek(0)
            reloaded_graph = QuantumGraph.load(buffer, universe=butler.dimensions)
            del buffer

            for graph in (built_graph, reloaded_graph):
                self.assertEqual(len(graph), 5)
                for index, node in enumerate(graph):
                    quantum = node.quantum
                    self.assertIsNotNone(quantum.datastore_records)

                    # Only the first quantum has a pre-existing input.
                    if index != 0:
                        self.assertEqual(quantum.datastore_records, {})
                        continue

                    self.assertEqual(set(quantum.datastore_records.keys()), {datastore_name})
                    records_data = quantum.datastore_records[datastore_name]
                    records_by_key = dict(records_data.records)
                    self.assertEqual(len(records_by_key), 1)
                    # Exactly one entry exists (asserted above); take its value.
                    (tables,) = records_by_key.values()
                    file_records = tables["file_datastore_records"]
                    self.assertEqual([entry.path for entry in file_records], expected_paths)

146 

147 

# When run as a script, call lsst.utils.tests.init() and then hand control
# to the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()