Coverage for tests/test_transform.py: 22%
97 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-03-20 00:56 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2024-03-20 00:56 -0700
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21"""Unit tests of transform.py"""
22import dataclasses
23import os
24import shutil
25import tempfile
26import unittest
28from cqg_test_utils import make_test_clustered_quantum_graph
29from lsst.ctrl.bps import BPS_SEARCH_ORDER, BpsConfig, GenericWorkflowJob
30from lsst.ctrl.bps.transform import _get_job_values, create_generic_workflow, create_generic_workflow_config
# Absolute path of the directory containing this test module; the workflow
# tests create (and later remove) their temporary directories under it.
TESTDIR = os.path.abspath(os.path.dirname(__file__))
class TestCreateGenericWorkflowConfig(unittest.TestCase):
    """Exercise create_generic_workflow_config."""

    def testCreate(self):
        """Check that the returned config keeps every submitted entry and
        gains the workflow name and workflow path entries."""
        prefix = "/test/create/prefix"
        submit_config = BpsConfig({"a": 1, "b": 2, "uniqProcName": "testCreate"})

        result = create_generic_workflow_config(submit_config, prefix)
        self.assertIsInstance(result, BpsConfig)

        # Every entry of the submitted config must be carried over unchanged.
        for name in submit_config:
            self.assertEqual(result[name], submit_config[name])

        # The name is expected to match uniqProcName; the path is the prefix argument.
        self.assertEqual(result["workflowName"], "testCreate")
        self.assertEqual(result["workflowPath"], prefix)
class TestCreateGenericWorkflow(unittest.TestCase):
    """Tests of create_generic_workflow."""

    def setUp(self):
        # Scratch directory passed to create_generic_workflow; removed in tearDown.
        self.tmpdir = tempfile.mkdtemp(dir=TESTDIR)
        self.config = BpsConfig(
            {
                "runInit": True,
                "computeSite": "global",
                "runQuantumCommand": "gexe -q {qgraphFile} --qgraph-node-id {qgraphNodeId}",
                "clusterTemplate": "{D1}_{D2}",
                "cluster": {
                    "cl1": {"pipetasks": "T1, T2", "dimensions": "D1, D2"},
                    "cl2": {"pipetasks": "T3, T4", "dimensions": "D1, D2"},
                },
                "cloud": {
                    "cloud1": {"runQuantumCommand": "c1exe -q {qgraphFile} --qgraph-node-id {qgraphNodeId}"},
                    "cloud2": {"runQuantumCommand": "c2exe -q {qgraphFile} --qgraph-node-id {qgraphNodeId}"},
                },
                "site": {
                    "site1": {"runQuantumCommand": "s1exe -q {qgraphFile} --qgraph-node-id {qgraphNodeId}"},
                    "site2": {"runQuantumCommand": "s2exe -q {qgraphFile} --qgraph-node-id {qgraphNodeId}"},
                    "global": {"runQuantumCommand": "s3exe -q {qgraphFile} --qgraph-node-id {qgraphNodeId}"},
                },
                # Needed because transform assumes they exist
                "whenSaveJobQgraph": "NEVER",
                "executionButler": {"whenCreate": "SUBMIT", "whenMerge": "ALWAYS"},
            },
            BPS_SEARCH_ORDER,
        )
        self.cqg = make_test_clustered_quantum_graph(self.config)

    def tearDown(self):
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def testCreatingGenericWorkflowGlobal(self):
        """Test creating a GenericWorkflow with global settings.

        computeCloud, computeSite and queue set at the top level of the config
        are expected on every job and on the final job.  The executable is
        expected to come from the site2 section's runQuantumCommand.
        """
        config = BpsConfig(self.config)
        config["computeCloud"] = "cloud1"
        config["computeSite"] = "site2"
        config["queue"] = "global_queue"
        workflow = create_generic_workflow(config, self.cqg, "test_gw", self.tmpdir)
        for jname in workflow:
            gwjob = workflow.get_job(jname)
            # subTest makes a failure report which job was wrong.
            with self.subTest(job=jname):
                self.assertEqual(gwjob.compute_site, "site2")
                self.assertEqual(gwjob.compute_cloud, "cloud1")
                self.assertEqual(gwjob.executable.src_uri, "s2exe")
                self.assertEqual(gwjob.queue, "global_queue")
        final = workflow.get_final()
        self.assertEqual(final.compute_site, "site2")
        self.assertEqual(final.compute_cloud, "cloud1")
        self.assertEqual(final.queue, "global_queue")

    def testCreatingQuantumGraphMixed(self):
        """Test creating a GenericWorkflow with per-cluster and final-job
        setting overrides.

        Jobs are checked by name prefix: cl1 and cl2 jobs use their cluster
        overrides, remaining pipetask jobs use the global computeSite, and the
        final job uses the executionButler overrides.
        """
        config = BpsConfig(self.config)
        config[".cluster.cl1.computeCloud"] = "cloud2"
        config[".cluster.cl1.computeSite"] = "notthere"
        config[".cluster.cl2.computeSite"] = "site1"
        config[".executionButler.queue"] = "special_final_queue"
        config[".executionButler.computeSite"] = "special_site"
        config[".executionButler.computeCloud"] = "special_cloud"
        workflow = create_generic_workflow(config, self.cqg, "test_gw", self.tmpdir)
        for jname in workflow:
            gwjob = workflow.get_job(jname)
            with self.subTest(job=jname):
                if jname.startswith("cl1"):
                    # "notthere" has no site section, so the command is
                    # expected to come from the cloud2 section instead.
                    self.assertEqual(gwjob.compute_site, "notthere")
                    self.assertEqual(gwjob.compute_cloud, "cloud2")
                    self.assertEqual(gwjob.executable.src_uri, "c2exe")
                elif jname.startswith("cl2"):
                    self.assertEqual(gwjob.compute_site, "site1")
                    self.assertIsNone(gwjob.compute_cloud)
                    self.assertEqual(gwjob.executable.src_uri, "s1exe")
                elif jname.startswith("pipetask"):
                    # Unclustered jobs fall back to the top-level computeSite "global".
                    self.assertEqual(gwjob.compute_site, "global")
                    self.assertIsNone(gwjob.compute_cloud)
                    self.assertEqual(gwjob.executable.src_uri, "s3exe")
        final = workflow.get_final()
        self.assertEqual(final.compute_site, "special_site")
        self.assertEqual(final.compute_cloud, "special_cloud")
        self.assertEqual(final.queue, "special_final_queue")
class TestGetJobValues(unittest.TestCase):
    """Tests of _get_job_values."""

    def setUp(self):
        # Reference job whose field values define the expected defaults.
        self.default_job = GenericWorkflowJob("default_job")

    def testGettingDefaults(self):
        """Test retrieving default values.

        With an empty config, every GenericWorkflowJob field returned by
        _get_job_values should equal the value on a freshly constructed job.
        """
        config = BpsConfig({})
        job_values = _get_job_values(config, {}, None)
        for field in dataclasses.fields(self.default_job):
            # One subTest per field: a failure names the field and shows both
            # values, unlike the previous assertTrue(all(...)) which showed neither.
            with self.subTest(field=field.name):
                self.assertEqual(getattr(self.default_job, field.name), job_values[field.name])

    def testEnablingMemoryScaling(self):
        """Test enabling the memory scaling mechanism."""
        config = BpsConfig({"memoryMultiplier": 2.0})
        job_values = _get_job_values(config, {}, None)
        self.assertAlmostEqual(job_values["memory_multiplier"], 2.0)
        # NOTE(review): 5 is the retry count expected when scaling is enabled
        # and no retry count is configured — confirm against transform.py.
        self.assertEqual(job_values["number_of_retries"], 5)

    def testDisablingMemoryScaling(self):
        """Test disabling the memory scaling mechanism.

        A multiplier below 1 (0.5 here) is expected to disable scaling, which
        shows up as memory_multiplier None.
        """
        config = BpsConfig({"memoryMultiplier": 0.5})
        job_values = _get_job_values(config, {}, None)
        self.assertIsNone(job_values["memory_multiplier"])

    def testRetrievingCmdLine(self):
        """Test retrieving the command line.

        The configured command is split into an executable (first token) and
        its arguments (the rest).
        """
        cmd_line_key = "runQuantum"
        config = BpsConfig({cmd_line_key: "/path/to/foo bar.txt"})
        job_values = _get_job_values(config, {}, cmd_line_key)
        self.assertEqual(job_values["executable"].name, "foo")
        self.assertEqual(job_values["executable"].src_uri, "/path/to/foo")
        self.assertEqual(job_values["arguments"], "bar.txt")
# Allow running this test module directly as a script.
# (The coverage run recorded this condition as never true.)
if __name__ == "__main__":
    unittest.main()