Coverage for tests/test_dotTools.py: 25%

73 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-23 10:58 +0000

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

"""Unit tests for ``lsst.ctrl.mpexec.dotTools.pipeline2dot``."""

30 

31import io 

32import re 

33import unittest 

34 

35import lsst.pipe.base.connectionTypes as cT 

36import lsst.utils.tests 

37from lsst.ctrl.mpexec.dotTools import pipeline2dot 

38from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections 

39 

40 

class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """Connections class used for testing.

    Declares two inputs and two outputs, all with ``visit``/``detector``
    dimensions and the ``example`` storage class.  The second input and the
    second output are optional: if the config gives them an empty (falsy)
    dataset type name they are removed when the connections are constructed.
    """

    input1 = cT.Input(
        name="",
        dimensions=["visit", "detector"],
        storageClass="example",
        doc="Input for this task",
    )
    input2 = cT.Input(
        name="",
        dimensions=["visit", "detector"],
        storageClass="example",
        doc="Input for this task",
    )
    output1 = cT.Output(
        name="",
        dimensions=["visit", "detector"],
        storageClass="example",
        doc="Output for this task",
    )
    output2 = cT.Output(
        name="",
        dimensions=["visit", "detector"],
        storageClass="example",
        doc="Output for this task",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Drop each optional connection whose configured name is empty.
        # The connection set is looked up only at removal time, the same
        # moment the original attribute access happened.
        for connectionName, setName in (("input2", "inputs"), ("output2", "outputs")):
            if not getattr(config.connections, connectionName):
                getattr(self, setName).remove(connectionName)

63 

64 

class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Configuration for the example test task.

    Defines no fields of its own; its ``connections`` sub-config is derived
    from `ExamplePipelineTaskConnections`.
    """

67 

68 

69def _makeConfig(inputName, outputName, pipeline, label): 

70 """Add config overrides. 

71 

72 Factory method for config instances. 

73 

74 inputName and outputName can be either string or tuple of strings 

75 with two items max. 

76 """ 

77 if isinstance(inputName, tuple): 

78 pipeline.addConfigOverride(label, "connections.input1", inputName[0]) 

79 pipeline.addConfigOverride(label, "connections.input2", inputName[1] if len(inputName) > 1 else "") 

80 else: 

81 pipeline.addConfigOverride(label, "connections.input1", inputName) 

82 

83 if isinstance(outputName, tuple): 

84 pipeline.addConfigOverride(label, "connections.output1", outputName[0]) 

85 pipeline.addConfigOverride(label, "connections.output2", outputName[1] if len(outputName) > 1 else "") 

86 else: 

87 pipeline.addConfigOverride(label, "connections.output1", outputName) 

88 

89 

class ExamplePipelineTask(PipelineTask):
    """Minimal pipeline task used by the dotTools tests.

    It defines no methods of its own; only its configuration class (and
    through it, its connections) is used by the tests.
    """

    ConfigClass = ExamplePipelineTaskConfig

94 

95 

def _makePipeline(tasks):
    """Build a test pipeline and return its expanded task definitions.

    Parameters
    ----------
    tasks : list of tuples
        Each tuple in the list has 3 or 4 items:
        - input DatasetType name(s), string or tuple of strings
        - output DatasetType name(s), string or tuple of strings
        - task label, string
        - optional task class object; if omitted or None,
          `ExamplePipelineTask` is used

    Returns
    -------
    taskDefs : `list`
        The items produced by ``Pipeline.toExpandedPipeline()``, collected
        into a list.
    """
    pipe = Pipeline("test pipeline")
    for task in tasks:
        inputs, outputs, label = task[:3]
        # The task class is documented as optional and allowed to be None;
        # fall back to the example task in both cases.
        klass = task[3] if len(task) > 3 and task[3] is not None else ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())

121 

122 

class DotToolsTestCase(unittest.TestCase):
    """A test case for dotTools."""

    def _assertQuoted(self, node):
        """Assert that a DOT node identifier is wrapped in double quotes.

        A lone ``"`` is rejected: its first and last characters are both
        quotes, but it is not a quoted identifier.
        """
        self.assertTrue(
            len(node) >= 2 and node.startswith('"') and node.endswith('"'),
            f"node name {node!r} is not quoted",
        )

    def testPipeline2dot(self):
        """Tests for dotTools.pipeline2dot method."""
        pipeline = _makePipeline(
            [
                ("A", ("B", "C"), "task0"),
                ("C", "E", "task1"),
                ("B", "D", "task2"),
                (("D", "E"), "F", "task3"),
                ("D.C", "G", "task4"),
                ("task3_metadata", "H", "task5"),
            ]
        )
        file = io.StringIO()
        pipeline2dot(pipeline, file)
        output = file.getvalue()

        # It's hard to validate the complete output, so just check a few
        # basic things; even that is not terribly stable.
        lines = output.strip().split("\n")
        nglobals = 3
        ndatasets = 10
        ntasks = 6
        nedges = 16
        nextra = 2  # graph header and closing
        self.assertEqual(len(lines), nglobals + ndatasets + ntasks + nedges + nextra)

        # Make sure that all node names are quoted.  Attribute lines for the
        # DOT keywords graph/node/edge are the only unquoted ones allowed.
        nodeRe = re.compile(r"^([^ ]+) \[.+\];$")
        edgeRe = re.compile(r"^([^ ]+) *-> *([^ ]+);$")
        for line in lines:
            match = nodeRe.match(line)
            if match:
                node = match.group(1)
                if node not in ("graph", "node", "edge"):
                    self._assertQuoted(node)
                continue
            match = edgeRe.match(line)
            if match:
                self._assertQuoted(match.group(1))
                self._assertQuoted(match.group(2))

        # Make sure components are connected to their parent dataset.
        self.assertIn('"D" -> "D.C"', output)

        # Make sure there is a connection created for metadata if someone
        # tries to read it in.
        self.assertIn('"task3" -> "task3_metadata"', output)

174 

175 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """File-handle leak check, inherited unchanged from `lsst.utils.tests.MemoryTestCase`."""

178 

179 

def setup_module(module):
    """Initialize the LSST test utilities before pytest runs this module.

    Parameters
    ----------
    module : `module`
        The module being set up; not used, but part of pytest's
        ``setup_module`` hook signature.
    """
    lsst.utils.tests.init()

183 

184 

if __name__ == "__main__":
    # Executed directly rather than through pytest: initialize the LSST test
    # utilities first, then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()