Coverage for tests/test_graphBuilder.py: 20%

65 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-30 10:01 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Tests of things related to the GraphBuilder class.""" 

29 

30import io 

31import logging 

32import unittest 

33 

34import lsst.utils.tests 

35from lsst.daf.butler.registry import UserExpressionError 

36from lsst.pipe.base import QuantumGraph 

37from lsst.pipe.base.graphBuilder import DatasetQueryConstraintVariant 

38from lsst.pipe.base.tests import simpleQGraph 

39from lsst.utils.tests import temporaryDirectory 

40 

# Module-level logger named after this module; it is not referenced anywhere
# else in the visible part of this file.
_LOG = logging.getLogger(__name__)

42 

43 

class GraphBuilderTestCase(unittest.TestCase):
    """Exercise quantum-graph construction via the ``simpleQGraph`` helpers."""

    def _assertGraph(self, graph: QuantumGraph) -> None:
        """Assert the expected basic shape of a freshly built graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            Graph whose init-output refs and input quanta are checked.
        """
        for task_def in graph.iterTaskGraph():
            # task has one initOutput, second ref is for config dataset
            self.assertEqual(len(graph.initOutputRefs(task_def)), 2)

        input_quanta = list(graph.inputQuanta)
        self.assertEqual(len(input_quanta), 1)

        # This includes only "packages" dataset for now.
        self.assertEqual(len(graph.globalInitOutputRefs()), 1)

    def testDefault(self):
        """Check that makeSimpleQGraph builds a graph, both with default
        settings and with explicit dataset query constraints.
        """
        with temporaryDirectory() as root:
            # makeSimpleQGraph calls GraphBuilder.
            butler, default_graph = simpleQGraph.makeSimpleQGraph(root=root)
            # by default makeSimpleQGraph makes a graph with 5 nodes
            self.assertEqual(len(default_graph), 5)
            self._assertGraph(default_graph)

            # Rebuild against the same butler with the dataset query
            # constraint switched off.
            _, unconstrained_graph = simpleQGraph.makeSimpleQGraph(
                butler=butler,
                datasetQueryConstraint=DatasetQueryConstraintVariant.OFF,
                callPopulateButler=False,
            )
            # When all outputs are random resolved refs, direct comparison
            # of graphs does not work because IDs are different. Can only
            # verify the number of quanta in the graph without doing something
            # terribly complicated.
            self.assertEqual(len(unconstrained_graph), 5)

            # Rebuild once more, constraining on a single dataset expression.
            expression_constraint = DatasetQueryConstraintVariant.fromExpression("add_dataset0")
            _, constrained_graph = simpleQGraph.makeSimpleQGraph(
                butler=butler,
                datasetQueryConstraint=expression_constraint,
                callPopulateButler=False,
            )
            self.assertEqual(len(constrained_graph), 5)

    def testAddInstrumentMismatch(self):
        """Verify that a UserExpressionError is raised if the instrument in
        the user query does not match the instrument in the pipeline.
        """
        with temporaryDirectory() as root:
            mismatched_pipeline = simpleQGraph.makeSimplePipeline(
                nQuanta=5,
                instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument",
            )
            with self.assertRaises(UserExpressionError):
                simpleQGraph.makeSimpleQGraph(
                    root=root,
                    pipeline=mismatched_pipeline,
                    userQuery="instrument = 'foo'",
                )

    def testUserQueryBind(self):
        """Verify that a user query works both with a literal value and with
        a bind value.
        """
        pipeline = simpleQGraph.makeSimplePipeline(
            nQuanta=5, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
        )
        instr = simpleQGraph.SimpleInstrument.getName()

        query_variants = (
            # With a literal in the user query
            {"userQuery": f"instrument = '{instr}'"},
            # With a bind for the user query
            {"userQuery": "instrument = instr", "bind": {"instr": instr}},
        )
        # Each variant gets its own fresh temporary directory; the test
        # passes as long as graph building does not raise.
        for query_kwargs in query_variants:
            with temporaryDirectory() as root:
                simpleQGraph.makeSimpleQGraph(root=root, pipeline=pipeline, **query_kwargs)

    def test_datastore_records(self):
        """Check datastore records on a built graph and on the same graph
        after a save/load round trip.
        """
        expected_paths = ["test/add_dataset0/add_dataset0_INSTR_det0_test.pickle"]
        datastore_name = "FileDatastore@<butlerRoot>"

        with temporaryDirectory() as root:
            # need FileDatastore for this test
            butler, built_graph = simpleQGraph.makeSimpleQGraph(
                root=root, inMemory=False, makeDatastoreRecords=True
            )

            # save and reload
            buffer = io.BytesIO()
            built_graph.save(buffer)
            buffer.seek(0)
            reloaded_graph = QuantumGraph.load(buffer, universe=butler.dimensions)
            del buffer

            for qgraph in (built_graph, reloaded_graph):
                self.assertEqual(len(qgraph), 5)
                for index, qnode in enumerate(qgraph):
                    quantum = qnode.quantum
                    self.assertIsNotNone(quantum.datastore_records)

                    # only the first quantum has a pre-existing input
                    if index != 0:
                        self.assertEqual(quantum.datastore_records, {})
                        continue

                    self.assertEqual(set(quantum.datastore_records.keys()), {datastore_name})
                    record_map = dict(quantum.datastore_records[datastore_name].records)
                    self.assertEqual(len(record_map), 1)
                    # Exactly one entry is present; take its value.
                    _, entry = record_map.popitem()
                    file_records = entry["file_datastore_records"]
                    self.assertEqual([rec.path for rec in file_records], expected_paths)

146 

147 

# Script entry point: runs only when this file is executed directly, not when
# it is imported.
if __name__ == "__main__":
    # NOTE(review): presumably sets up LSST test utilities before the run;
    # confirm against lsst.utils.tests.
    lsst.utils.tests.init()
    # Discover and run the test cases defined in this module.
    unittest.main()